In [1]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

DEBUG:root:test


In [2]:
import crypten
import torch

crypten.init()
torch.set_num_threads(1)

#ignore warnings
import warnings; 
warnings.filterwarnings("ignore")

INFO:root:DistributedCommunicator with rank 0
INFO:root:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:3 to store for rank: 0
INFO:root:World size = 1


In [3]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
import itertools

from torchvision import transforms, datasets, models

from PIL import Image
import numpy as np
import pandas as pd

In [4]:
class MiniONN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 1)
        
        self.conv6 = nn.Conv2d(64, 64, 1, 1, 0)
        self.conv7 = nn.Conv2d(64, 16, 1, 1, 0)
        
        self.fc = nn.Linear(1024, 1)
        
        self.avg1 = nn.AvgPool2d(2, 2)
        self.avg2 = nn.AvgPool2d(2, 2)
        
        
    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        
        h = self.avg1(h)
        
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        
        h = self.avg2(h)
        
        h = F.relu(self.conv5(h))
        h = F.relu(self.conv6(h))
        h = F.relu(self.conv7(h))
        
        h = h.view(-1, 1024)
        h = self.fc(h)
        
        return h
        
model_ft = MiniONN()

In [5]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn=nn*s
        pp+=nn
    return pp

get_n_params(model_ft)

155729

In [6]:
ALICE = 0
BOB = 1

In [7]:
%%time
import time

import crypten.mpc as mpc
import crypten.communicator as comm

transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

CPU times: user 89 µs, sys: 18 µs, total: 107 µs
Wall time: 110 µs


In [8]:
# labels = torch.load('/tmp/bob_test_labels.pth').long()
# count = 100 # For illustration purposes, we'll use only 100 samples for classification

@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    crypten.comm.get().set_verbosity(True)
    # Load pre-trained model to Alice
    model = crypten.load('../training/models/minionn/checkpoint_cpu_cpu.pt', dummy_model=model_ft, src=ALICE)
    
    # Encrypt model from Alice 
    dummy_input = torch.empty((1, 3, 32, 32))
    private_model = crypten.nn.from_pytorch(model.double(), dummy_input.double())

    private_model.encrypt(src=ALICE)
    
    # Load data to Bob
    data_enc = crypten.cryptensor(transform('../training/dataset/COVID/COVID-1.png').unsqueeze(0), src=BOB)

    # Classify the encrypted data
    private_model.eval()
    start = time.time()

    output_enc = private_model(data_enc)
    end = time.time()



    # Compute the accuracy
    output = output_enc.get_plain_text()
    print('Output class: ', torch.sigmoid(output), 'Time: ', end-start)
    crypten.print_communication_stats()
    
encrypt_model_and_data()

INFO:root:DistributedCommunicator with rank 0
INFO:root:DistributedCommunicator with rank 1
INFO:root:Added key: store_based_barrier_key:1 to store for rank: 1
INFO:root:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:2 to store for rank: 1
INFO:root:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:3 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:3 to store for rank: 1
INFO:root:Added key: store_based_barrier_key:4 to store for rank: 0
INFO:root:World size = 2
INFO:root:Added key: store_based_barrier_key:4 to store for rank: 1
INFO:root:World size = 2
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'tIME' 41 7
DEBUG:PIL.PngImagePlugin:b'tIME' 41 7 (unknown)
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 60 8192
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'tIME' 41 7
DEBUG:PIL.PngImagePlugin:b'tIME' 41 

Output class:  tensor([[0.8914]]) Time:  1.9540951251983643


INFO:root:====Communication Stats====
INFO:root:Rounds: 72


Output class: 

INFO:root:Bytes : 42309128


 

INFO:root:Comm time: 0.05111907500668167


tensor([[0.8914]]) Time:  1.958292007446289


INFO:root:====Communication Stats====
INFO:root:Rounds: 72
INFO:root:Bytes : 42309128
INFO:root:Comm time: 0.09995209000862815
INFO:root:DistributedCommunicator with rank 0
INFO:root:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:root:Added key: store_based_barrier_key:3 to store for rank: 0
INFO:root:World size = 1


[None, None]