#### Load the library

In [1]:
import torch
import torchvision
import crypten

# Initialise CrypTen before any CrypTen call is made further down.
crypten.init()
# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

# Set random seed for reproducibility
# (being the last expression of the cell, the returned Generator is echoed as output).
torch.manual_seed(1)

<torch._C.Generator at 0x7f4f3c0616d0>

#### Every user has a unique id associated with computation

In [2]:
#### Load MNIST dataset

In [3]:
# Hyper-parameters.
n_epochs = 3
learning_rate = 0.01
momentum = 0.5
log_interval = 10

# The test batch size exceeds the 10000 test images reported by the "Explore the
# data" cell, so the test loader yields one batch holding the whole split.
# Presumably the same holds for the training split - TODO confirm.
batch_size_train = 70000
batch_size_test = 20000

# Preprocessing applied to both splits: convert to a tensor, then normalise with
# mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/tmp/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/tmp/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test,
    shuffle=True,
)

#### Explore the data

In [4]:
# Peek at the first batch of the plain (unencrypted) test loader.
# The iterator is deliberately NOT called `test`: that name is used for the
# evaluation function defined in the model cell, and re-running this cell
# afterwards would otherwise overwrite that function with an enumerate object.
test_batches = enumerate(test_loader)
batch_idx, (test_data, test_targets) = next(test_batches)
print(batch_idx, test_data.shape)

0 torch.Size([10000, 1, 28, 28])


#### Encrypt the MNIST data

In [5]:
from crypten import mpc

# Party id; in this cell it is only referenced by the commented-out
# crypten.save(..., src=uid) calls.
uid = 0

# Destination file for each piece of data that gets saved.
filenames = dict(
    train_data="/tmp/traindata.pth",
    train_labels="/tmp/trainlabels.pth",
    test_data="/tmp/testdata.pth",
    test_labels="/tmp/testlabels.pth",
)

def _encode_and_save(loader, data_path, labels_path):
    """Encode every sample produced by ``loader`` with CrypTen and save the result.

    All batches are concatenated first, so nothing is dropped when the loader
    yields more than one batch.

    Args:
        loader: iterable of ``(data, label)`` tensor batches.
        data_path: destination file for the encoded data tensor.
        labels_path: destination file for the encoded label tensor.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batches = list(loader)
    if not batches:
        raise ValueError("loader produced no batches; nothing to save")

    data = torch.cat([batch for batch, _ in batches])
    labels = torch.cat([target for _, target in batches])

    # Reach into the CrypTen tensor for its underlying torch tensor so that it
    # can be written with plain torch.save.
    # NOTE(review): `_tensor._tensor` is a private attribute chain; presumably it
    # is this party's share of the fixed-point encoding - confirm against the
    # CrypTen version in use.
    cdata = crypten.cryptensor(data)._tensor._tensor
    clabels = crypten.cryptensor(labels)._tensor._tensor

    # Save using normal pytorch
    torch.save(cdata, data_path)
    torch.save(clabels, labels_path)


# @mpc.run_multiprocess(world_size=1)
def save_all_data():
    """Encode the MNIST train and test splits and write them to ``filenames``.

    Reads the module-level ``train_loader`` and ``test_loader`` and writes four
    files: train data/labels and test data/labels. Prints a confirmation when
    all four have been written.

    The CrypTen-native alternative would be
    ``crypten.save(tensor, path, src=uid)``.
    """
    # Train split -> train_data / train_labels
    _encode_and_save(train_loader, filenames["train_data"], filenames["train_labels"])
    # Test split -> test_data / test_labels
    _encode_and_save(test_loader, filenames["test_data"], filenames["test_labels"])

    print('Save Successful')


save_all_data()

Save Successful


#### Load the encrypted data

In [13]:
# one_hot_labels = F.one_hot(torch.arange(0, 10), num_classes=10)

# Maps each distinct stored (encoded) label value to a contiguous class index.
# Rebuilt every time load_train_data() runs.
label_dict = {}


def load_train_data():
    """Load the saved training tensors and rebuild the global ``label_dict``.

    Returns:
        tuple: ``(data_enc, label_enc)`` where ``data_enc`` is the stored data
        converted to float and divided, position by position, by its maximum
        over dimension 0, and ``label_enc`` is the stored label tensor as-is.

    Side effects:
        Replaces the module-level ``label_dict`` with a mapping from every
        distinct value in ``label_enc`` (in sorted order) to ``0..k-1``.
    """
    global label_dict

    # Load the encrypted data
    data_enc = torch.load(filenames["train_data"]).float()
    label_enc = torch.load(filenames["train_labels"])

    # torch.unique returns the distinct values sorted, so class indices follow
    # the numeric order of the stored labels. Enumerating (instead of zipping
    # with a fixed range of 10) keeps every class even if there are more than ten.
    distinct_labels = torch.unique(label_enc).tolist()
    label_dict = {value: class_idx for class_idx, value in enumerate(distinct_labels)}

    # Scale each position by its maximum across the samples (dimension 0).
    # NOTE(review): a position whose maximum is exactly 0 would divide by zero.
    data_enc = data_enc / data_enc.max(0, keepdim=True)[0]

    # Crypten way of loading, for reference:
    #   crypten.load(filenames["train_data"], src=uid)
    return data_enc, label_enc

def load_test_data():
    """Load the saved test tensors.

    Returns:
        tuple: ``(data_enc, label_enc)`` where ``data_enc`` is the stored test
        data converted to float and divided, position by position, by its own
        maximum over dimension 0, and ``label_enc`` is the stored label tensor
        unchanged.
    """
    pixels = torch.load(filenames["test_data"]).float()
    label_enc = torch.load(filenames["test_labels"])

    # The scale comes from the test tensor itself (max over dimension 0).
    per_position_max, _ = pixels.max(0, keepdim=True)
    return pixels / per_position_max, label_enc
            
# _,_ = load_train_data()    

In [7]:
# labels = range(10,20)
# label_dict = dict(zip(labels, range(0,10)))
# print(label_dict)


#### Create a custom dataset to load the encrypted data

In [14]:
from torch.utils.data import Dataset, DataLoader

class EncryptedMNIST(Dataset):
    """Dataset over the tensors returned by the load_*_data helpers.

    Args:
        train: when truthy the data comes from ``load_train_data()``,
            otherwise from ``load_test_data()``.

    Each item is ``(image, class_index)``; the class index is looked up in the
    module-level ``label_dict`` at access time.
    """

    def __init__(self, train=True):
        loader_fn = load_train_data if train else load_test_data
        self.data, self.label = loader_fn()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        # Translate the stored label value into its class index.
        class_index = label_dict[self.label[index].item()]
        return image, class_index

#### Define loader for Encrypted data

In [15]:
# Build both datasets first, then wrap each one in a loader.
train_encrypt_mnist = EncryptedMNIST(train=True)
test_encrypt_mnist = EncryptedMNIST(train=False)

# Neither loader reshuffles: batches follow the order the tensors were saved in.
train_loader_enc = DataLoader(dataset=train_encrypt_mnist, batch_size=256, shuffle=False)
test_loader_enc = DataLoader(dataset=test_encrypt_mnist, batch_size=1000, shuffle=False)


In [10]:
# print(label_dict)
# for data, label in train_loader_enc:
#     print(data)

#### Model to train encrypted MNIST

In [22]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Cross-entropy criterion shared at module level.
# Note: CrossEntropyLoss applies log-softmax internally, while Net.forward below
# already returns log_softmax output; log-softmax is idempotent, so feeding it
# Net's output gives the same value as F.nll_loss would.
error = torch.nn.CrossEntropyLoss()

class Net(nn.Module):
    """Small CNN classifier: two 3x3 convolutions, max-pool, two linear layers.

    ``fc1`` expects 9216 features, i.e. 64 channels * 12 * 12, which is what the
    convolution/pool stack produces for a 1x28x28 input. ``forward`` returns
    log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the flattened (batch, features) activation.
        # Dropout2d is deprecated for 2-D input; nn.Dropout is the replacement
        # PyTorch recommends to retain the previous behaviour.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10) for input ``x``."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    """Run one training epoch over ``train_loader``.

    The model is expected to return log-probabilities (as ``Net.forward`` does),
    so the negative log-likelihood loss is applied directly. This is the same
    criterion ``test()`` reports, and it avoids a second log-softmax pass.

    Args:
        model: module to optimise; put into training mode here.
        device: device that each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimiser stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
        log_interval: when set to a positive integer, print progress every
            ``log_interval`` batches (requires ``train_loader.dataset`` and
            ``len(train_loader)``). ``None`` (default) prints nothing.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # test_loss = error(output, target)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


In [24]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.01)

# Multiply the learning rate by 0.8 after every epoch.
scheduler = StepLR(optimizer, step_size=1, gamma=0.8)
for epoch in range(1, 14 + 1):
    train(model, device, train_loader_enc, optimizer, epoch)
    test(model, device, test_loader_enc)
    scheduler.step()

# Persist the trained weights.
torch.save(model.state_dict(), "/tmp/mnist_cnn.pt")


Test set: Average loss: 1.3499, Accuracy: 7697/10000 (77%)


Test set: Average loss: 0.5533, Accuracy: 8636/10000 (86%)


Test set: Average loss: 0.4296, Accuracy: 8870/10000 (89%)


Test set: Average loss: 0.3798, Accuracy: 8965/10000 (90%)


Test set: Average loss: 0.3534, Accuracy: 9025/10000 (90%)


Test set: Average loss: 0.3359, Accuracy: 9068/10000 (91%)


Test set: Average loss: 0.3248, Accuracy: 9086/10000 (91%)


Test set: Average loss: 0.3167, Accuracy: 9106/10000 (91%)


Test set: Average loss: 0.3090, Accuracy: 9120/10000 (91%)


Test set: Average loss: 0.3051, Accuracy: 9131/10000 (91%)


Test set: Average loss: 0.3012, Accuracy: 9137/10000 (91%)


Test set: Average loss: 0.2984, Accuracy: 9142/10000 (91%)


Test set: Average loss: 0.2965, Accuracy: 9147/10000 (91%)


Test set: Average loss: 0.2948, Accuracy: 9157/10000 (92%)

