<h1> Encrypted Inference-Linear Regression</h1>

In this tutorial, we train a Linear regression model in plaintext on Boston Housing Dataset. Then we use the model for performing encrypted and plaintext inference on test data. This tutorial uses protocol Falcon for 3 parties and SPDZ for 3 and 5 parties. It depicts how you can perform encrypted inference on test data with nearly the same accuracy as plaintext. 

In [1]:
import pandas as pd
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.utils.data as data_utils

In [3]:
#Set a manual seed to maintain consistency
torch.manual_seed(0)

<torch._C.Generator at 0x7fabf4c432b0>

<h2>Data Loading and Processing</h1>

In [4]:
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data

--2021-07-13 19:26:30--  https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49082 (48K) [application/x-httpd-php]
Saving to: ‘housing.data.1’


2021-07-13 19:26:32 (70.2 KB/s) - ‘housing.data.1’ saved [49082/49082]



In [5]:
#Import dataset
dataset=pd.read_csv("housing.data",delim_whitespace=True,
                    names=["crim","zn","indus",
                           "chas","nox","rm",
                           "age","dis","rad",
                           "tax","ptratio","black",
                           "lstat","medv"])

In [6]:
#Visualize and look at columns and rows of dataset
dataset.head()

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat,medv
0,0.00632,18.0,2.31,0,0.538,6.575,65.2,4.09,1,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0,0.469,6.421,78.9,4.9671,2,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0,0.469,7.185,61.1,4.9671,2,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0,0.458,6.998,45.8,6.0622,3,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0,0.458,7.147,54.2,6.0622,3,222.0,18.7,396.9,5.33,36.2


In [7]:
X_data = dataset.drop("medv",axis=1)
y_data = dataset["medv"]

In [8]:
X_data = X_data.apply(
    lambda x: (x - x.mean()) / x.std()
)

In [9]:
y_data

0      24.0
1      21.6
2      34.7
3      33.4
4      36.2
       ... 
501    22.4
502    20.6
503    23.9
504    22.0
505    11.9
Name: medv, Length: 506, dtype: float64

In [10]:
X_data

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat
0,-0.419367,0.284548,-1.286636,-0.272329,-0.144075,0.413263,-0.119895,0.140075,-0.981871,-0.665949,-1.457558,0.440616,-1.074499
1,-0.416927,-0.487240,-0.592794,-0.272329,-0.739530,0.194082,0.366803,0.556609,-0.867024,-0.986353,-0.302794,0.440616,-0.491953
2,-0.416929,-0.487240,-0.592794,-0.272329,-0.739530,1.281446,-0.265549,0.556609,-0.867024,-0.986353,-0.302794,0.396035,-1.207532
3,-0.416338,-0.487240,-1.305586,-0.272329,-0.834458,1.015298,-0.809088,1.076671,-0.752178,-1.105022,0.112920,0.415751,-1.360171
4,-0.412074,-0.487240,-1.305586,-0.272329,-0.834458,1.227362,-0.510674,1.076671,-0.752178,-1.105022,0.112920,0.440616,-1.025487
...,...,...,...,...,...,...,...,...,...,...,...,...,...
501,-0.412820,-0.487240,0.115624,-0.272329,0.157968,0.438881,0.018654,-0.625178,-0.981871,-0.802418,1.175303,0.386834,-0.417734
502,-0.414839,-0.487240,0.115624,-0.272329,0.157968,-0.234316,0.288648,-0.715931,-0.981871,-0.802418,1.175303,0.440616,-0.500355
503,-0.413038,-0.487240,0.115624,-0.272329,0.157968,0.983986,0.796661,-0.772919,-0.981871,-0.802418,1.175303,0.440616,-0.982076
504,-0.407361,-0.487240,0.115624,-0.272329,0.157968,0.724955,0.736268,-0.667776,-0.981871,-0.802418,1.175303,0.402826,-0.864446


In [11]:
features = torch.tensor(X_data.values.astype(np.float16).astype(np.float32)) 
targets = torch.tensor(y_data.values.astype(np.float16).astype(np.float32))

In [12]:
features

tensor([[-0.4194,  0.2847, -1.2871,  ..., -1.4580,  0.4407, -1.0742],
        [-0.4170, -0.4873, -0.5928,  ..., -0.3027,  0.4407, -0.4919],
        [-0.4170, -0.4873, -0.5928,  ..., -0.3027,  0.3960, -1.2080],
        ...,
        [-0.4131, -0.4873,  0.1156,  ...,  1.1758,  0.4407, -0.9819],
        [-0.4075, -0.4873,  0.1156,  ...,  1.1758,  0.4028, -0.8643],
        [-0.4146, -0.4873,  0.1156,  ...,  1.1758,  0.4407, -0.6685]])

In [13]:
# Arguments of projects
batch_size = 16
epochs = 500
train_test_split = 0.8
lr = 0.001

In [14]:
train_x = features[:int(len(features)*train_test_split)]
train_y = targets[:int(len(features)*train_test_split)]

test_x = features[int(len(features)*train_test_split)+1:]
test_y = targets[int(len(features)*train_test_split)+1:]

In [15]:
def get_batches(X,y):
    batches = []
    for index in range(0,len(train_x)+1,batch_size):
        batches.append((X[index:index+batch_size],y[index:index+batch_size]))
    
    return batches

In [16]:
train_batches=get_batches(train_x,train_y)

<h1>Plaintext Training</h1>

In [17]:
#Import syft
import syft as sy

In [18]:
#Define Linear regression model
class LinearSyNet(sy.Module):
    def __init__(self, torch_ref):
        super(LinearSyNet, self).__init__(torch_ref=torch_ref)
        self.fc1 = self.torch_ref.nn.Linear(13,1)

    def forward(self, x):
        x = self.fc1(x)
        return x

In [19]:
#Define model, loss function and optimizer
model = LinearSyNet(torch)
criterion = torch.nn.MSELoss(reduction='mean') 
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [20]:
#Training Loop
for epoch in range(epochs):
  running_loss = 0.0
  for index in range(0,len(train_batches)):
    # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients
    optimizer.zero_grad()

    # get output from the model, given the inputs
    outputs = model(train_batches[index][0]).reshape([-1])

    # get loss for the predicted output
    loss = criterion(outputs,train_batches[index][1])
    running_loss += loss
    # get gradients w.r.t to parameters
    loss.backward()

    # update parameters
    optimizer.step()
    
  test_accuracy = criterion(model(test_x).reshape([-1]),test_y)
  print(f"Epoch {epoch}/{epochs}  Running Loss : {running_loss.item()/batch_size} and test loss: {test_accuracy.item()}")

Epoch 0/500  Running Loss : 978.0725708007812 and test loss: 367.93927001953125
Epoch 1/500  Running Loss : 831.0911865234375 and test loss: 411.168212890625
Epoch 2/500  Running Loss : 720.55322265625 and test loss: 440.28497314453125
Epoch 3/500  Running Loss : 633.6892700195312 and test loss: 457.1828308105469
Epoch 4/500  Running Loss : 562.962646484375 and test loss: 464.29931640625
Epoch 5/500  Running Loss : 503.80291748046875 and test loss: 463.9801330566406
Epoch 6/500  Running Loss : 453.3403015136719 and test loss: 458.2357177734375
Epoch 7/500  Running Loss : 409.6958312988281 and test loss: 448.6855773925781
Epoch 8/500  Running Loss : 371.5797424316406 and test loss: 436.5837707519531
Epoch 9/500  Running Loss : 338.0628967285156 and test loss: 422.87548828125
Epoch 10/500  Running Loss : 308.4455261230469 and test loss: 408.2576904296875
Epoch 11/500  Running Loss : 282.1790466308594 and test loss: 393.2343444824219
Epoch 12/500  Running Loss : 258.81988525390625 and tes

Epoch 109/500  Running Loss : 42.54023361206055 and test loss: 43.331031799316406
Epoch 110/500  Running Loss : 42.491207122802734 and test loss: 42.86627960205078
Epoch 111/500  Running Loss : 42.44314956665039 and test loss: 42.410919189453125
Epoch 112/500  Running Loss : 42.39602279663086 and test loss: 41.96467590332031
Epoch 113/500  Running Loss : 42.34979248046875 and test loss: 41.527347564697266
Epoch 114/500  Running Loss : 42.304443359375 and test loss: 41.09870147705078
Epoch 115/500  Running Loss : 42.259925842285156 and test loss: 40.67849349975586
Epoch 116/500  Running Loss : 42.21622848510742 and test loss: 40.26654815673828
Epoch 117/500  Running Loss : 42.17331314086914 and test loss: 39.86261749267578
Epoch 118/500  Running Loss : 42.131160736083984 and test loss: 39.466590881347656
Epoch 119/500  Running Loss : 42.089759826660156 and test loss: 39.0782356262207
Epoch 120/500  Running Loss : 42.04906463623047 and test loss: 38.69734573364258
Epoch 121/500  Running 

Epoch 220/500  Running Loss : 39.87311553955078 and test loss: 21.747699737548828
Epoch 221/500  Running Loss : 39.861854553222656 and test loss: 21.6955623626709
Epoch 222/500  Running Loss : 39.850711822509766 and test loss: 21.644638061523438
Epoch 223/500  Running Loss : 39.83968734741211 and test loss: 21.594924926757812
Epoch 224/500  Running Loss : 39.82878112792969 and test loss: 21.546396255493164
Epoch 225/500  Running Loss : 39.817996978759766 and test loss: 21.49904441833496
Epoch 226/500  Running Loss : 39.80730438232422 and test loss: 21.45285415649414
Epoch 227/500  Running Loss : 39.79674530029297 and test loss: 21.407787322998047
Epoch 228/500  Running Loss : 39.786285400390625 and test loss: 21.363840103149414
Epoch 229/500  Running Loss : 39.77593231201172 and test loss: 21.32098960876465
Epoch 230/500  Running Loss : 39.76569366455078 and test loss: 21.279216766357422
Epoch 231/500  Running Loss : 39.75556182861328 and test loss: 21.238508224487305
Epoch 232/500  Ru

Epoch 332/500  Running Loss : 39.112548828125 and test loss: 20.458087921142578
Epoch 333/500  Running Loss : 39.10874557495117 and test loss: 20.46868324279785
Epoch 334/500  Running Loss : 39.104976654052734 and test loss: 20.47944450378418
Epoch 335/500  Running Loss : 39.10124206542969 and test loss: 20.490375518798828
Epoch 336/500  Running Loss : 39.09754180908203 and test loss: 20.501455307006836
Epoch 337/500  Running Loss : 39.09386444091797 and test loss: 20.5126895904541
Epoch 338/500  Running Loss : 39.090232849121094 and test loss: 20.524080276489258
Epoch 339/500  Running Loss : 39.08661651611328 and test loss: 20.535627365112305
Epoch 340/500  Running Loss : 39.08304977416992 and test loss: 20.547306060791016
Epoch 341/500  Running Loss : 39.07950210571289 and test loss: 20.55913543701172
Epoch 342/500  Running Loss : 39.07598114013672 and test loss: 20.57109832763672
Epoch 343/500  Running Loss : 39.0724983215332 and test loss: 20.583200454711914
Epoch 344/500  Running 

Epoch 445/500  Running Loss : 38.83354187011719 and test loss: 22.129133224487305
Epoch 446/500  Running Loss : 38.8320198059082 and test loss: 22.144758224487305
Epoch 447/500  Running Loss : 38.83050537109375 and test loss: 22.160348892211914
Epoch 448/500  Running Loss : 38.829010009765625 and test loss: 22.17591094970703
Epoch 449/500  Running Loss : 38.82752227783203 and test loss: 22.191442489624023
Epoch 450/500  Running Loss : 38.82603454589844 and test loss: 22.20697021484375
Epoch 451/500  Running Loss : 38.824581146240234 and test loss: 22.222454071044922
Epoch 452/500  Running Loss : 38.823116302490234 and test loss: 22.237926483154297
Epoch 453/500  Running Loss : 38.8216667175293 and test loss: 22.253389358520508
Epoch 454/500  Running Loss : 38.82023239135742 and test loss: 22.268815994262695
Epoch 455/500  Running Loss : 38.818809509277344 and test loss: 22.284194946289062
Epoch 456/500  Running Loss : 38.8173942565918 and test loss: 22.299558639526367
Epoch 457/500  Ru

<h1>Plaintext Inference</h1>

In [21]:
plaintext_predictions = model(test_x).reshape([-1])

In [22]:
print("MSE Loss: ",criterion(plaintext_predictions,test_y).item())

MSE Loss:  22.930004119873047


<h1>Encrypted Inference</h1>

In [23]:
#Syft and SyMPC imports required for encrypted inference
import syft as sy
from sympc.module.nn import mse_loss
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.config import Config
from sympc.protocol import Falcon,FSS
import time

In [24]:
def get_clients(n_parties):
  #Generate required number of syft clients and return them.

  parties=[]
  for index in range(n_parties): 
      parties.append(sy.VirtualMachine(name = "worker"+str(index)).get_root_client())

  return parties

In [25]:
def inference(n_clients,protocol):
    
  # Get VM clients 
  parties=get_clients(n_clients)

  # Setup the session for the computation
  session = Session(parties = parties,protocol = protocol)
  SessionManager.setup_mpc(session)

  #Encrypt model 
  mpc_model = model.share(session)

  #Encrypt test data
  test_data=MPCTensor(secret=test_x, session = session)

  #Perform inference and measure time taken
  start_time = time.time()
  enc_results = mpc_model(test_data)
  end_time = time.time()

  print(f"Time for inference: {end_time-start_time}s")

  #Get plaintext predictions
  predictions = enc_results.reconstruct()
  
  #Calculate Loss
  print("MSE Loss: ",criterion(predictions.reshape([-1]),test_y).item())
    
  return predictions

In [26]:
predictions=inference(3,Falcon("semi-honest"))

Time for inference: 0.03281688690185547s
MSE Loss:  22.929933547973633


We can see that the prediction values and mean squared error values are almost the same as final model. Small differences are due to precision loss.

In [27]:
for index in range(0,10):
    print(f"Index {index}")
    print(f"Encrypted Prediction Output {predictions[index].item()}")
    print(f"Plaintext Prediction Output {plaintext_predictions[index].item()}")
    print(f"Expected Prediction: {test_y[index]}")
    print("\n")

Index 0
Encrypted Prediction Output 1.840240478515625
Plaintext Prediction Output 1.8402048349380493
Expected Prediction: 5.0


Index 1
Encrypted Prediction Output 5.1915740966796875
Plaintext Prediction Output 5.191580772399902
Expected Prediction: 11.8984375


Index 2
Encrypted Prediction Output 19.7861328125
Plaintext Prediction Output 19.786155700683594
Expected Prediction: 27.90625


Index 3
Encrypted Prediction Output 13.1734619140625
Plaintext Prediction Output 13.173490524291992
Expected Prediction: 17.203125


Index 4
Encrypted Prediction Output 20.987045288085938
Plaintext Prediction Output 20.987075805664062
Expected Prediction: 27.5


Index 5
Encrypted Prediction Output 14.210540771484375
Plaintext Prediction Output 14.210532188415527
Expected Prediction: 15.0


Index 6
Encrypted Prediction Output 19.272994995117188
Plaintext Prediction Output 19.273014068603516
Expected Prediction: 17.203125


Index 7
Encrypted Prediction Output 1.7180023193359375
Plaintext Prediction Outp

<h1> Conclusion </h1>

Falcon can also work with malicious security guarantee but at a large inference time. 

In [28]:
predictions=inference(3,Falcon("malicious"))

Time for inference: 0.5620832443237305s
MSE Loss:  22.929933547973633


In [29]:
predictions=inference(3,FSS())

[2021-07-13T19:26:39.657868+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: fbaddfc859fb4a36ace0f4f3fdf82efa>.
[2021-07-13T19:26:39.660616+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: b087878d1e9542369b7087da48075252>.
[2021-07-13T19:26:39.663202+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: f150e136b5a94831b5eb8b296f86b40b>.


Time for inference: 0.5415430068969727s
MSE Loss:  22.929967880249023


In [30]:
predictions=inference(5,FSS())

[2021-07-13T19:26:41.369789+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: ec3046e80a7149c9830ee8d73af69433>.
[2021-07-13T19:26:41.374661+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 5a617c91666648e298e42a1c964fe0a7>.
[2021-07-13T19:26:41.376340+0530][CRITICAL][logger]][32938] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 3570c369a5ba44c7be3b2b1f0456af00>.


Time for inference: 0.9285099506378174s
MSE Loss:  22.929994583129883



| Protocol | Security Type| Parties | Inference Time (s) |
| --- | --- | --- | --- |
| Falcon | Semi-honest | 3 | 0.03391 |
| Falcon | Malicious | 3 | 0.61146 |
| FSS| Semi-honest | 3 | 0.56047 |
| FSS | Semi-honest | 5 | 0.91995|

Falcon provides fast inference for 3 parties in semi-honest setting. While, Functional Secret Sharing (FSS) allows inference for N number of parties. Both allow inference with reasonable accuracy. 

<h3>What's next?</h3>

SyMPC is still under development! We will add here more features as soon they are stable enough, stay tuned! 🕺

If you enjoyed this tutorial, show your support by Starring SyMPC! 🙏