In [1]:
import syft as sy
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.protocol import Falcon,FSS
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import os
from typing import Any, Dict, List
import copy
import random
import time
import numpy as np
from tqdm import tqdm
import gc

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

batch_size =200
epochs = 55
train_test_split= 0.8
lr = 0.001
n_client = 70
chosen_prob = 0.6
local_batch_size = 32
local_epochs = 10


class Block(nn.Module):
  def __init__(self, inchannel, outchannel, res=True,stride=0):
    super(Block, self).__init__()
    self.res = res     # With or without residual connection
    self.left = nn.Sequential(
        nn.Conv2d(inchannel, outchannel, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(outchannel),
        nn.ReLU(inplace=True),
        nn.Conv2d(outchannel, outchannel, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(outchannel),
    )
    if stride != 1 or inchannel != outchannel:
        self.shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
    else:
        self.shortcut = nn.Sequential()

    self.relu = nn.Sequential(
        nn.ReLU(inplace=True),
    )

  def forward(self, x):
    out = self.left(x)
    if self.res:
        out += self.shortcut(x)
    out = self.relu(out)
    return out


class Resnet(nn.Module):
  def __init__(self, cfg=[64, 'M', 128,  'M', 256, 'M', 512, 'M'], res=True):
      super(Resnet, self).__init__()
      self.res = res       # With or without residual connection
      self.cfg = cfg       # Configuration list
      self.inchannel = 3   # Number of initial input channels
      self.futures = self.make_layer()
      # The full connection layer and classifier after the construction of the convolution layer:
      self.classifier = nn.Sequential(nn.Dropout(0.4),            
          nn.Linear(4 * 512, 10), )   # fc

  def make_layer(self):
    layers = []
    for v in self.cfg:
      if v == 'M':
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
      else:
        layers.append(Block(self.inchannel, v, self.res))
        self.inchannel = v    # Change the number of input channels to the number of output channels of the previous layer
    return nn.Sequential(*layers)

  def forward(self, x):
    out = self.futures(x)
    # view(out.size(0), -1): change tensor size from (N ,H , W) to (N, H*W)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out


class ResNet18(sy.Module):
  def __init__(self, torch_ref):
    super(ResNet18, self).__init__(torch_ref=torch_ref)
    self.model=Resnet()
  def forward(self, x):
    x = self.model.forward(x.reconstruct().cuda())
    return x

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='cifar_data', train=True,
                  download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='cifar_data',
                  train=False,
                  download=True,
                  transform=transform)

trainloader = torch.utils.data.DataLoader(trainset,
                     batch_size=batch_size,
                     shuffle=True)
testloader = torch.utils.data.DataLoader(testset,
                    batch_size=batch_size,
                    shuffle=True)



# Get the specified number of clients
def get_clients(n_clients):
  clients = []
  for index in range(n_clients):
      clients.append(sy.VirtualMachine(name="worker" + str(index)).get_root_client())
  return clients

# Divide the data set according to the number of clients
def split_send(data, session):
  data_pointers = []
  split_size = int(len(data) / len(session.parties)) + 1
  for index in range(0, len(session.parties)):
      ptr = data[index * split_size:index * split_size + split_size].share(session=session)
      data_pointers.append(ptr)

  return data_pointers

def train_model(n_clients,dataloader, protocol=None):

  # Get client
  parties = get_clients(n_clients)
  print(parties)
  # Initialize the connection pool according to the number of clients
  # if(protocol):
  #   session = Session(parties=parties, protocol=protocol)
  # else:
  #   session = Session(parties=parties)
  # SessionManager.setup_mpc(session)
  for epoch in tqdm(range(epochs)):
    running_loss = 0.0
    total_correct = 0
    if(protocol):
      session = Session(parties=parties, protocol=protocol)
    else:
      session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    optimizer.zero_grad()
    count=0
    for index, (feature, label) in enumerate(dataloader):
      
      # Split data and send to session
      pointers = split_send(feature, session)
      # Model encryption
      mpc_model = model.share(session)
      results = []
      # The data set is encrypted and calculated using MPC
      for ptr in pointers:
        smpc_results = mpc_model(ptr)
        results.append(smpc_results)
      predictions = torch.cat(results)
      # Loss caculation
      loss = criterion(predictions, label.cuda())
      loss.backward()
      optimizer.step()
      running_loss += np.round(loss.item(),4)
      y_pred_decode = torch.argmax(predictions, dim=1)
      total_correct += y_pred_decode.eq(label.cuda()).sum().item()
      if index >=100:
        break
      
    acc = total_correct/((index+1)*batch_size)
    print(f"Epoch {epoch}/{epochs}  Loss:{running_loss/batch_size} ,accuracy:{acc}")
    torch.cuda.empty_cache()
    test_feature=None
    test_label=None
    
    
def test_model(n_clients,dataloader, protocol=None):
  parties = get_clients(n_clients)
  print(parties)
  for epoch in tqdm(range(1)):
    total_correct = 0
    if(protocol):
      session = Session(parties=parties, protocol=protocol)
    else:
      session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    optimizer.zero_grad()
    for _, (feature, label) in enumerate(dataloader):
      
      # Split data and send to session
      pointers = split_send(feature, session)
      # Model encryption
      mpc_model = model.share(session)
      results = []
      # The data set is encrypted and calculated using MPC
      for ptr in pointers:
        smpc_results = mpc_model(ptr)
        results.append(smpc_results)
      predictions = torch.cat(results)
      y_pred_decode = torch.argmax(predictions, dim=1)
      total_correct += y_pred_decode.eq(label.cuda()).sum().item()
    acc = total_correct/10000
    print(f"test accuracy:{acc}")
    torch.cuda.empty_cache()
    test_feature=None
    test_label=None

model = ResNet18(torch)
model.model.to(device)
sympc.module.SKIP_LAYERS_NAME={"Flatten","Resnet","Small_Model"}
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
print('------------------train------------------------')
train_model(3,trainloader,Falcon("semi-honest"))

------------------train------------------------
[<VirtualMachineClient: worker0 Client>, <VirtualMachineClient: worker1 Client>, <VirtualMachineClient: worker2 Client>]


  2%|▏         | 1/55 [01:40<1:30:38, 100.71s/it]

Epoch 0/55  Loss:1.0267425000000001 ,accuracy:0.3613861386138614


  4%|▎         | 2/55 [03:22<1:29:26, 101.25s/it]

Epoch 1/55  Loss:0.9341705000000001 ,accuracy:0.4226237623762376


  5%|▌         | 3/55 [05:04<1:28:02, 101.60s/it]

Epoch 2/55  Loss:0.732844 ,accuracy:0.5012376237623762


  7%|▋         | 4/55 [06:45<1:26:09, 101.36s/it]

Epoch 3/55  Loss:0.6641140000000001 ,accuracy:0.5482673267326733


  9%|▉         | 5/55 [08:27<1:24:37, 101.56s/it]

Epoch 4/55  Loss:0.6034349999999999 ,accuracy:0.5886138613861386


 11%|█         | 6/55 [10:08<1:22:58, 101.61s/it]

Epoch 5/55  Loss:0.5320875 ,accuracy:0.6342079207920792


 13%|█▎        | 7/55 [11:50<1:21:10, 101.47s/it]

Epoch 6/55  Loss:0.47034099999999995 ,accuracy:0.6736633663366337


 15%|█▍        | 8/55 [13:32<1:19:36, 101.62s/it]

Epoch 7/55  Loss:0.45711799999999997 ,accuracy:0.6813861386138614


 16%|█▋        | 9/55 [15:13<1:17:49, 101.52s/it]

Epoch 8/55  Loss:0.4503205000000002 ,accuracy:0.6883168316831683


 18%|█▊        | 10/55 [16:55<1:16:14, 101.65s/it]

Epoch 9/55  Loss:0.42638999999999994 ,accuracy:0.707079207920792


 20%|██        | 11/55 [18:36<1:14:26, 101.52s/it]

Epoch 10/55  Loss:0.39685550000000014 ,accuracy:0.7258415841584158


 22%|██▏       | 12/55 [20:18<1:12:49, 101.62s/it]

Epoch 11/55  Loss:0.3799985 ,accuracy:0.7380693069306931


 24%|██▎       | 13/55 [21:59<1:11:07, 101.61s/it]

Epoch 12/55  Loss:0.3607125 ,accuracy:0.7561386138613861


 25%|██▌       | 14/55 [23:41<1:09:24, 101.58s/it]

Epoch 13/55  Loss:0.31869700000000006 ,accuracy:0.7800495049504951


 27%|██▋       | 15/55 [25:22<1:07:39, 101.49s/it]

Epoch 14/55  Loss:0.3026965 ,accuracy:0.7894554455445545


 29%|██▉       | 16/55 [27:04<1:06:03, 101.62s/it]

Epoch 15/55  Loss:0.2984265 ,accuracy:0.7948019801980198


 31%|███       | 17/55 [28:46<1:04:25, 101.71s/it]

Epoch 16/55  Loss:0.2845830000000001 ,accuracy:0.8035148514851486


 33%|███▎      | 18/55 [30:28<1:02:39, 101.62s/it]

Epoch 17/55  Loss:0.28020450000000013 ,accuracy:0.8061881188118812


 35%|███▍      | 19/55 [32:10<1:01:03, 101.77s/it]

Epoch 18/55  Loss:0.2717375000000001 ,accuracy:0.8154455445544554


 36%|███▋      | 20/55 [33:51<59:20, 101.72s/it]  

Epoch 19/55  Loss:0.24904849999999995 ,accuracy:0.8299504950495049


 38%|███▊      | 21/55 [35:33<57:38, 101.72s/it]

Epoch 20/55  Loss:0.23612949999999994 ,accuracy:0.8386633663366336


 40%|████      | 22/55 [37:15<56:04, 101.96s/it]

Epoch 21/55  Loss:0.2278105 ,accuracy:0.8485148514851485


 42%|████▏     | 23/55 [38:58<54:23, 101.98s/it]

Epoch 22/55  Loss:0.21955 ,accuracy:0.8511386138613861


 44%|████▎     | 24/55 [40:40<52:44, 102.07s/it]

Epoch 23/55  Loss:0.20749750000000003 ,accuracy:0.8586633663366336


 45%|████▌     | 25/55 [42:22<51:04, 102.15s/it]

Epoch 24/55  Loss:0.202858 ,accuracy:0.8598514851485148


 47%|████▋     | 26/55 [44:05<49:31, 102.47s/it]

Epoch 25/55  Loss:0.19686350000000008 ,accuracy:0.8638613861386139


 49%|████▉     | 27/55 [45:47<47:43, 102.26s/it]

Epoch 26/55  Loss:0.17057150000000004 ,accuracy:0.8813366336633663


 51%|█████     | 28/55 [47:29<46:00, 102.26s/it]

Epoch 27/55  Loss:0.17037000000000005 ,accuracy:0.8813861386138614


 53%|█████▎    | 29/55 [49:12<44:23, 102.44s/it]

Epoch 28/55  Loss:0.16232049999999995 ,accuracy:0.8866831683168317


 55%|█████▍    | 30/55 [50:55<42:42, 102.49s/it]

Epoch 29/55  Loss:0.1489425 ,accuracy:0.8968811881188119


 56%|█████▋    | 31/55 [52:37<40:58, 102.43s/it]

Epoch 30/55  Loss:0.13871700000000003 ,accuracy:0.9053465346534654


 58%|█████▊    | 32/55 [54:19<39:15, 102.40s/it]

Epoch 31/55  Loss:0.14268750000000002 ,accuracy:0.9011881188118812


 60%|██████    | 33/55 [56:02<37:34, 102.48s/it]

Epoch 32/55  Loss:0.13958399999999999 ,accuracy:0.901930693069307


 62%|██████▏   | 34/55 [57:45<35:54, 102.58s/it]

Epoch 33/55  Loss:0.12084699999999997 ,accuracy:0.915990099009901


 64%|██████▎   | 35/55 [59:28<34:12, 102.64s/it]

Epoch 34/55  Loss:0.11803149999999998 ,accuracy:0.9182673267326733


 65%|██████▌   | 36/55 [1:01:10<32:30, 102.64s/it]

Epoch 35/55  Loss:0.11435750000000003 ,accuracy:0.921980198019802


 67%|██████▋   | 37/55 [1:02:53<30:49, 102.74s/it]

Epoch 36/55  Loss:0.10547100000000002 ,accuracy:0.9257920792079208


 69%|██████▉   | 38/55 [1:04:37<29:09, 102.92s/it]

Epoch 37/55  Loss:0.10843150000000001 ,accuracy:0.9244059405940594


 71%|███████   | 39/55 [1:06:20<27:30, 103.14s/it]

Epoch 38/55  Loss:0.10676400000000001 ,accuracy:0.9271287128712872


 73%|███████▎  | 40/55 [1:08:04<25:50, 103.34s/it]

Epoch 39/55  Loss:0.09495150000000001 ,accuracy:0.9356930693069307


 75%|███████▍  | 41/55 [1:09:48<24:07, 103.38s/it]

Epoch 40/55  Loss:0.09853400000000001 ,accuracy:0.9300990099009901


 76%|███████▋  | 42/55 [1:11:31<22:24, 103.45s/it]

Epoch 41/55  Loss:0.10924500000000002 ,accuracy:0.9248019801980198


 78%|███████▊  | 43/55 [1:13:16<20:45, 103.76s/it]

Epoch 42/55  Loss:0.101325 ,accuracy:0.9296039603960397


 80%|████████  | 44/55 [1:15:00<19:02, 103.91s/it]

Epoch 43/55  Loss:0.10106649999999999 ,accuracy:0.9308415841584159


 82%|████████▏ | 45/55 [1:16:44<17:19, 103.91s/it]

Epoch 44/55  Loss:0.07670499999999998 ,accuracy:0.9466831683168316


 84%|████████▎ | 46/55 [1:18:28<15:35, 103.99s/it]

Epoch 45/55  Loss:0.0775865 ,accuracy:0.9471287128712871


 85%|████████▌ | 47/55 [1:20:12<13:52, 104.04s/it]

Epoch 46/55  Loss:0.07112000000000003 ,accuracy:0.9512376237623762


 87%|████████▋ | 48/55 [1:21:57<12:09, 104.24s/it]

Epoch 47/55  Loss:0.08510349999999998 ,accuracy:0.940049504950495


 89%|████████▉ | 49/55 [1:23:42<10:26, 104.38s/it]

Epoch 48/55  Loss:0.07092700000000005 ,accuracy:0.9496039603960396


 91%|█████████ | 50/55 [1:25:26<08:42, 104.43s/it]

Epoch 49/55  Loss:0.06295249999999998 ,accuracy:0.958019801980198


 93%|█████████▎| 51/55 [1:27:11<06:58, 104.68s/it]

Epoch 50/55  Loss:0.06025899999999998 ,accuracy:0.9583168316831683


 95%|█████████▍| 52/55 [1:28:56<05:14, 104.69s/it]

Epoch 51/55  Loss:0.06395100000000004 ,accuracy:0.9545544554455445


 96%|█████████▋| 53/55 [1:30:41<03:29, 104.75s/it]

Epoch 52/55  Loss:0.0614275 ,accuracy:0.958069306930693


 98%|█████████▊| 54/55 [1:32:25<01:44, 104.64s/it]

Epoch 53/55  Loss:0.04560200000000001 ,accuracy:0.9673762376237623


100%|██████████| 55/55 [1:34:11<00:00, 102.75s/it]

Epoch 54/55  Loss:0.0407055 ,accuracy:0.972029702970297





In [4]:
print('------------------test------------------------')
test_model(3,testloader,Falcon("semi-honest"))

------------------test------------------------
[<VirtualMachineClient: worker0 Client>, <VirtualMachineClient: worker1 Client>, <VirtualMachineClient: worker2 Client>]


100%|██████████| 1/1 [00:52<00:00, 52.03s/it]

test accuracy:0.8128



