In [1]:
import torch
from torchvision import datasets, transforms


torch.manual_seed(73)

# Mean / std handed to transforms.Normalize for the single image channel.
NORMALIZE_MEAN = (0.1307,)
NORMALIZE_STD = (0.3081,)

# One preprocessing pipeline shared by both splits: tensor conversion, then normalisation.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Only the training split requests a download; the test split reads from the same root.
train_dataset = datasets.MNIST('./MNIST', train=True, download=True,
                               transform=mnist_transform)
test_dataset = datasets.MNIST('./MNIST', train=False,
                              transform=mnist_transform)

# Training uses shuffled mini-batches of 64; the test loader yields one shuffled sample at a time.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

In [2]:
import torch.nn as nn
import torch.optim as optim


class Model(nn.Module):
    """Fully connected classifier with square activations.

    The input is reshaped to rows of 28*28 features and passed through three
    linear layers (784 -> 256 -> 64 -> 10). The outputs of the first two layers
    are squared element-wise; the last layer's output goes through log-softmax.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (N, 10) for ``x`` reshaped to (N, 784)."""
        hidden = x.reshape(-1, 28 * 28)
        for layer in (self.fc1, self.fc2):
            pre_activation = layer(hidden)
            # Square activation: the layer output multiplied by itself.
            hidden = pre_activation * pre_activation
        logits = self.fc3(hidden)
        return nn.functional.log_softmax(logits, dim=1)

In [3]:
def train(model, device, train_loader, optimizer, epochs):
    """Optimise ``model`` with NLL loss for ``epochs`` passes over ``train_loader``.

    Each batch is moved to ``device`` before the forward pass. A progress line
    with the current loss is printed on every 100th batch index (including 0).

    Args:
        model: module whose output is fed to ``nll_loss``.
        device: device the input and target batches are moved to.
        train_loader: iterable of ``(data, target)`` batches; ``len()`` and
            ``.dataset`` are read for the progress line.
        optimizer: optimizer stepped once per batch.
        epochs: number of full passes over ``train_loader``.

    Returns:
        The same ``model`` object, switched to evaluation mode.
    """
    progress_line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    for epoch in range(1, epochs + 1):
        for step, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = nn.functional.nll_loss(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()

            # Report only on every 100th batch.
            if step % 100 != 0:
                continue
            samples_seen = step * len(inputs)
            percent_done = 100. * step / len(train_loader)
            print(progress_line.format(
                epoch, samples_seen, len(train_loader.dataset),
                percent_done, batch_loss.item()))

    model.eval()
    return model

In [4]:
# Choose the device first and place the model on it before building the optimizer:
# train() and test() move every batch to `device`, so a model left on the CPU would
# raise a device-mismatch error as soon as CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.001)

# Train for 10 epochs; train() returns the same model switched to eval mode.
model = train(model, device, train_loader, optimizer, 10)



In [5]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [6]:
# Run the evaluation routine on the trained model using the test loader.
test(model, device, test_loader)


Test set: Average loss: 0.3157, Accuracy: 9197/10000 (92%)



In [7]:
import tenseal as ts

class HEModel:
    """Evaluates the three linear layers on CKKS-encrypted inputs.

    The weights (transposed) and biases of the three given layers are stored as
    plain Python lists. Each input sample is flattened, encrypted with the given
    TenSEAL context, pushed through the layers with a square after the first
    two, and decrypted; log-softmax is applied to the decrypted values.
    """

    def __init__(self, ts_context, fc1, fc2, fc3):
        """Keep the context and copy the layers' parameters into lists.

        Args:
            ts_context: context passed to ``ts.ckks_vector`` when encrypting.
            fc1, fc2, fc3: layers exposing ``weight`` and ``bias`` tensors.
        """
        self.ts_context = ts_context
        # Weights are transposed before conversion so they can be used as the
        # right-hand operand of the vector-matrix product in the forward pass.
        self.fc1_weight, self.fc1_bias = fc1.weight.t().tolist(), fc1.bias.tolist()
        self.fc2_weight, self.fc2_bias = fc2.weight.t().tolist(), fc2.bias.tolist()
        self.fc3_weight, self.fc3_bias = fc3.weight.t().tolist(), fc3.bias.tolist()

    def _decrypted_logits(self, sample):
        """Encrypt one sample, apply the three layers, and return the decrypted result."""
        enc = ts.ckks_vector(self.ts_context, sample.flatten().tolist())
        squared_layers = (
            (self.fc1_weight, self.fc1_bias),
            (self.fc2_weight, self.fc2_bias),
        )
        # First and second layer: linear map followed by the square activation.
        for weight, bias in squared_layers:
            enc = enc.mm(weight) + bias
            enc *= enc
        # Third layer has no activation.
        enc = enc.mm(self.fc3_weight) + self.fc3_bias
        return enc.decrypt()

    def forward(self, x):
        """Return log-softmax (over dim 1) of the decrypted outputs, one row per sample of ``x``."""
        rows = [self._decrypted_logits(x[idx]) for idx in range(x.shape[0])]
        return nn.functional.log_softmax(torch.tensor(rows), dim=1)

    def __call__(self, x):
        return self.forward(x)


In [8]:
# CKKS parameters. The inner primes have the same bit size as the global scale.
# HEModel.forward performs five multiplicative steps (three matrix products and
# two squarings); presumably each consumes one of the five inner primes --
# confirm against the TenSEAL documentation before changing these values.
POLY_MODULUS_DEGREE = 8192
SCALE_BITS = 21
INNER_PRIME_COUNT = 5
COEFF_MOD_BIT_SIZES = [40] + [SCALE_BITS] * INNER_PRIME_COUNT + [40]

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    POLY_MODULUS_DEGREE,
    coeff_mod_bit_sizes=COEFF_MOD_BIT_SIZES,
)
context.global_scale = 2 ** SCALE_BITS
context.generate_galois_keys()

# Wrap the trained layers so they can be evaluated on encrypted inputs.
he_model = HEModel(context, model.fc1, model.fc2, model.fc3)

In [9]:
from time import time

t_start = time()

test_loss = 0
correct = 0
count = 120  # upper bound on the number of test batches pushed through the encrypted model
evaluated = 0  # samples actually scored; the averages below are taken over this number
with torch.no_grad():
    # zip() against range(count) stops after `count` batches without fetching an extra one.
    for _, (data, target) in zip(range(count), test_loader):
        # he_model builds its output from decrypted Python lists, i.e. a CPU tensor.
        # The batch is therefore left where the loader produced it instead of being
        # moved to `device`; moving the labels to CUDA would make nll_loss/eq fail
        # with a device mismatch.
        output = he_model(data)
        test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        evaluated += len(target)

# Guard the averages against an empty loader.
denominator = max(evaluated, 1)
test_loss /= denominator

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, evaluated,
    100. * correct / denominator))

t_end = time()
print(f"Evaluation of {evaluated} pictures took {(t_end - t_start) / 60} min")


Test set: Average loss: 1.0815, Accuracy: 113/120 (94%)

Evaluation of 120 pictures took 31.11869985659917 min


## Possible improvements

- Shrink the hidden layers (256 -> 128 and 64 -> 32).
- Use a smaller poly_modulus: use coeff mod [60, 40, 40, 60] with a 20-bit scale (2 ** 20) so that we can do 2 multiplications before rescaling.
- Restructure the diagonals into a single plaintext.