In [1]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
import time

torch.manual_seed(73)

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

class FullyConnectedNet(torch.nn.Module):
    def __init__(self, hidden=64, output=10):
        
        super(FullyConnectedNet, self).__init__()  
        self.fc1 = torch.nn.Linear(784, 1024)
        self.fc2 = torch.nn.Linear(1024, hidden)
        self.fc3 = torch.nn.Linear(hidden, output)

    def forward(self, x):
        x = x.view(-1, 784)
        # flattening while keeping the batch axis
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        x = x * x
        x = self.fc3(x)
        return x


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    model.train()
    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # calculate average losses
        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    
    # model in evaluation mode
    model.eval()
    return model


model = FullyConnectedNet()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


start = time.time()
model = train(model, train_loader, criterion, optimizer, 10)
end = time.time()

print("model training takes",end - start,"s")

# PATH = "./mnist_cnn.pth"
# torch.save(model, PATH)






Epoch: 1 	Training Loss: 0.316995
Epoch: 2 	Training Loss: 0.184800
Epoch: 3 	Training Loss: 0.186411
Epoch: 4 	Training Loss: 0.129476
Epoch: 5 	Training Loss: 0.092804
Epoch: 6 	Training Loss: 0.387427
Epoch: 7 	Training Loss: 0.063582
Epoch: 8 	Training Loss: 0.102195
Epoch: 9 	Training Loss: 0.172201
Epoch: 10 	Training Loss: 0.130927
model training takes 425.63089323043823 s


In [2]:
def test(model, test_loader, criterion):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    # model in evaluation mode
    model.eval()

    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1
        

    # calculate and print avg test loss
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% ' 
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

start = time.time()
test(model, test_loader, criterion)
end = time.time()

print("model testing takes",end - start,"s")

Test Loss: 0.377657

Test Accuracy of 0: 98% (970/980)
Test Accuracy of 1: 98% (1114/1135)
Test Accuracy of 2: 94% (980/1032)
Test Accuracy of 3: 99% (1001/1010)
Test Accuracy of 4: 94% (927/982)
Test Accuracy of 5: 94% (843/892)
Test Accuracy of 6: 98% (942/958)
Test Accuracy of 7: 96% (988/1028)
Test Accuracy of 8: 88% (864/974)
Test Accuracy of 9: 97% (986/1009)

Test Accuracy (Overall): 96% (9615/10000)
model testing takes 2.337642192840576 s


In [21]:
import tenseal as ts


class EncFullyConnectedNet:
    def __init__(self, torch_nn):
        
        self.fc1_weight = torch_nn.fc1.weight.T.data.tolist()
        self.fc1_bias = torch_nn.fc1.bias.data.tolist()
        
        self.fc2_weight = torch_nn.fc2.weight.T.data.tolist()
        self.fc2_bias = torch_nn.fc2.bias.data.tolist()
        
        self.fc3_weight = torch_nn.fc3.weight.T.data.tolist()
        self.fc3_bias = torch_nn.fc3.bias.data.tolist()
        
        
    def forward(self, enc_x):
        # conv layer
#         enc_channels = []
#         for kernel, bias in zip(self.conv1_weight, self.conv1_bias):
#             y = enc_x.conv2d_im2col(kernel, windows_nb) + bias
#             enc_channels.append(y)
#         # pack all channels into a single flattened vector
#         enc_x = ts.CKKSVector.pack_vectors(enc_channels)
        # square activation
#         enc_x.square_()
        # fc1 layer
        start = time.time()
        enc_x = enc_x.mm(self.fc1_weight) + self.fc1_bias
        end = time.time()
        print("fc1 takes", end - start)
        
        # square activation
        start = time.time()
        enc_x.square_()
        end = time.time()
        print("square activation takes", end - start)
        
        # fc2 layer
        start = time.time()
        enc_x = enc_x.mm(self.fc2_weight) + self.fc2_bias
        end = time.time()
        print("fc2 takes", end - start)
        
        start = time.time()
        enc_x.square_()
        end = time.time()
        print("square activation takes", end - start)
        
        # fc3 layer
        start = time.time()
        enc_x = enc_x.mm(self.fc3_weight) + self.fc3_bias
        end = time.time()
        print("fc3 takes", end - start)
        
        return enc_x
    
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    
def enc_test(context, model, test_loader, criterion):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    
    cnt = 0
    for data, target in test_loader:
        
#         Encoding and encryption
        print(data.view(-1,784).shape)
        x_enc = ts.ckks_vector(context, data.view(-1,784)[0])
#         x_enc = ts.CKKSVector.pack_vectors(vector)
        # Encrypted evaluation
        enc_output = enc_model(x_enc)
        # Decryption of result
        output = enc_output.decrypt()
        output = torch.tensor(output).view(1, -1)

        # compute loss
        loss = criterion(output, target)
        test_loss += loss.item()
        
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        label = target.data[0]
        class_correct[label] += correct.item()
        class_total[label] += 1
        cnt += 1
        if cnt == 50:
            break


    # calculate and print avg test loss
    test_loss = test_loss / sum(class_total)
    print(f'Test Loss: {test_loss:.6f}\n')
    
    print(class_correct)
    print(class_total)



# Load one element at a time
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

batch_size = 64
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
# required for encoding
# enc_x_test = [ts.ckks_vector(data.view(28, 28).tolist()) for x in x_test]
# kernel_shape = model.conv1.kernel_size
# stride = model.conv1.stride[0]

In [22]:
## Encryption Parameters

# controls precision of the fractional part
bits_scale = 26

# Create TenSEAL context
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[31, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, 31]
)

# set the scale
context.global_scale = pow(2, bits_scale)

# galois keys are required to do ciphertext rotations
context.generate_galois_keys()

In [23]:
enc_model = EncFullyConnectedNet(model)
start = time.time()
enc_test(context, enc_model, test_loader, criterion)
end = time.time()
print("model testing takes",end - start,"s")

torch.Size([1, 784])
fc1 takes 45.57144021987915
square activation takes 0.0527040958404541
fc2 takes 45.67418575286865
square activation takes 0.03564119338989258
fc3 takes 1.0732967853546143
torch.Size([1, 784])
fc1 takes 60.14111924171448
square activation takes 0.0604250431060791
fc2 takes 37.30438280105591
square activation takes 0.03211045265197754
fc3 takes 0.884279727935791
torch.Size([1, 784])
fc1 takes 54.65120458602905
square activation takes 0.07305121421813965
fc2 takes 51.56203842163086
square activation takes 0.02988290786743164
fc3 takes 0.8409469127655029
torch.Size([1, 784])
fc1 takes 44.9692747592926
square activation takes 0.057878971099853516
fc2 takes 41.756065130233765
square activation takes 0.042777299880981445
fc3 takes 1.0879099369049072
torch.Size([1, 784])
fc1 takes 65.7329773902893
square activation takes 0.06248283386230469
fc2 takes 35.96977663040161
square activation takes 0.044799089431762695
fc3 takes 0.8356921672821045
torch.Size([1, 784])
fc1 takes 

fc1 takes 66.11450791358948
square activation takes 0.05590009689331055
fc2 takes 40.23895192146301
square activation takes 0.028014183044433594
fc3 takes 0.9329793453216553
torch.Size([1, 784])
fc1 takes 70.04368686676025
square activation takes 0.060019731521606445
fc2 takes 39.61665940284729
square activation takes 0.028958559036254883
fc3 takes 0.817812442779541
torch.Size([1, 784])
fc1 takes 52.97729206085205
square activation takes 0.05684685707092285
fc2 takes 57.4970383644104
square activation takes 0.026390790939331055
fc3 takes 1.105400562286377
torch.Size([1, 784])
fc1 takes 52.244409799575806
square activation takes 0.041884660720825195
fc2 takes 42.76051306724548
square activation takes 0.030915498733520508
fc3 takes 1.0813548564910889
torch.Size([1, 784])
fc1 takes 73.71582579612732
square activation takes 0.03586983680725098
fc2 takes 39.2504198551178
square activation takes 0.02382969856262207
fc3 takes 0.8622565269470215
torch.Size([1, 784])
fc1 takes 56.0594642162323
