In [3]:
import math
import torch
import tenseal as ts
import numpy as np
from skimage.util.shape import view_as_windows

In [4]:
def almost_equal(vec1, vec2, m_pow_ten):
    """Return True if two sequences agree element-wise within 10**-m_pow_ten.

    Args:
        vec1, vec2: sized, iterable sequences of numbers (e.g. lists or 1-D
            numpy arrays).
        m_pow_ten: exponent of the absolute tolerance; every pair must satisfy
            ``abs(v1 - v2) <= 10 ** -m_pow_ten``.

    Returns:
        False if the lengths differ or any pair is farther apart than the
        tolerance (a NaN in either input also counts as a mismatch);
        True otherwise.
    """
    if len(vec1) != len(vec2):
        return False

    upper_bound = pow(10, -m_pow_ten)
    # Phrased as "<=" inside all() rather than "> -> return False" so that NaN,
    # which fails every comparison, is reported as a mismatch instead of passing.
    return all(abs(v1 - v2) <= upper_bound for v1, v2 in zip(vec1, vec2))

### Torch 2d convolution

In [5]:
def torch_conv_2d(x, kernel, stride):
    """Plaintext reference 2-D convolution with no padding and no dilation.

    Args:
        x: input tensor laid out as (batch, channels, height, width), as
            ``torch.nn.functional.conv2d`` requires.
        kernel: weight tensor laid out as (out_channels, in_channels, kH, kW).
        stride: stride applied along both spatial axes.

    Returns:
        The tensor produced by ``torch.nn.functional.conv2d``.
    """
    conv2d = torch.nn.functional.conv2d
    result = conv2d(x, kernel, stride=stride, padding=0, dilation=1)
    return result

### Generate input

In [1]:
# Input image is x_size * x_size; kernel is k_size * k_size.
x_size = 3
k_size = 2
# Convolution stride (same along both axes).
stride = 1

# Fixed seed so Restart & Run All reproduces the same input and outputs.
SEED = 42
# True -> deterministic incremented values 1, 2, ..., n^2; False -> random ints in [0, 10).
USE_INCREMENTED_VALUES = False

if USE_INCREMENTED_VALUES:
    x = np.arange(1, x_size ** 2 + 1).reshape(x_size, x_size)
    kernel = np.arange(1, k_size ** 2 + 1).reshape(k_size, k_size)
else:
    rng = np.random.default_rng(SEED)
    # Both arrays must be square (n * n) to match the 2-D convolution set-up.
    x = rng.integers(0, 10, size=(x_size, x_size))
    kernel = rng.integers(0, 10, size=(k_size, k_size))

print("input", x.shape)
print(x)
print("kernel", kernel.shape)
print(kernel)

NameError: name 'np' is not defined

### TenSEAL context

In [36]:
# CKKS parameters: polynomial modulus degree and the bit sizes passed as
# coeff_mod_bit_sizes.
POLY_MOD_DEGREE = 8192
COEFF_MOD_BIT_SIZES = [60, 40, 40, 60]
# Scale assigned to the context's global_scale (2^40).
GLOBAL_SCALE = 2 ** 40

# Build the TenSEAL CKKS context.
context = ts.context(
    ts.SCHEME_TYPE.CKKS, POLY_MOD_DEGREE, coeff_mod_bit_sizes=COEFF_MOD_BIT_SIZES
)
context.global_scale = GLOBAL_SCALE
# Galois keys are required to rotate ciphertext vectors.
context.generate_galois_keys()

For each convolution layer, one round of client-server communication is required. The server sends the ciphertext (encrypted vector) that is the input of the next convolution layer to the client, which decrypts it and applies im2col (image block to column) to that input.

After that, the client encodes and encrypts the resulting input matrix in a vertical scan (column-major order) and sends it back to the server.

In [None]:
%%time 

x_enc, windows_nb = ts.im2col_encoding(context, x, kernel.shape[0], kernel.shape[1], stride)

print("windows number: ", windows_nb)
print("ckksvector size: ", x_enc.size())

y_enc = x_enc.conv2d_im2col(kernel.tolist(), windows_nb)

print(y_enc.size())
y_plain = y_enc.decrypt()

print("y_enc")
print(y_plain)

### Compare the result to torch conv2d

In [34]:
# torch conv2d expects (batch, channels, H, W) input and
# (out_channels, in_channels, kH, kW) weights, hence the two unsqueeze calls.
x_torch = torch.from_numpy(x.astype("float32")).unsqueeze(0).unsqueeze(0)
kernel_torch = torch.from_numpy(kernel.astype("float32")).unsqueeze(0).unsqueeze(0)
y_torch = torch_conv_2d(x_torch, kernel_torch, stride).flatten().numpy()
print("y_torch")
print(y_torch)

# CKKS is approximate, so compare with an absolute tolerance of 10^0 = 1.
assert almost_equal(y_plain, y_torch, 0), (
    "encrypted im2col conv2d result differs from torch conv2d"
)

y_toch
[ -54.58763     87.4021      47.199284    34.291573  -111.8195
    6.2725477  100.97608   -106.42866    -31.869497    43.06913
  110.91937     69.30648     76.14888    177.25304    105.908844
  -21.30675     12.912507   -27.024754   -16.25838    -47.6781
   81.64753    137.79376     91.30269    -94.195076  -196.98035
 -125.99463     42.880825  -131.80736    -78.094765    59.072765
   69.54799    -77.53858    -67.61382    -79.19424    -82.42037
  -55.323532    21.918015    81.582664    59.006775   -86.694984
  -50.26876    -30.892807     9.307717  -133.45465    -92.55416
  -23.359943    23.213211    -2.048564   -80.353806   -11.300457
    3.8368444   36.394527    -8.637391     5.187211    44.294132
   49.788925    35.600857   -23.631056    46.156494   -38.97396
   56.187077    64.65246    -97.72869   -177.6504     -57.25484
    8.721375    81.09992     31.083414     3.267823   -94.50235
 -157.55186      1.2791833  -48.5475    -114.12285     75.05805
   54.170834    12.75259     1