In [1]:
import math
import torch
import tenseal as ts
import numpy as np
from skimage.util.shape import view_as_windows

In [2]:
def almost_equal(vec1, vec2, m_pow_ten):
    """Check that two sequences agree element-wise within ``10 ** -m_pow_ten``.

    Args:
        vec1, vec2: sized iterables of numbers (lists, 1-D numpy arrays, ...).
        m_pow_ten: decimal precision; the absolute tolerance is
            ``10 ** -m_pow_ten`` (e.g. 0 -> 1, 3 -> 0.001).

    Returns:
        True when both inputs have the same length and every pair of
        elements differs by at most the tolerance; False otherwise.
        A NaN in either input makes the result False.
    """
    if len(vec1) != len(vec2):
        return False

    upper_bound = pow(10, -m_pow_ten)
    # Every comparison with NaN is False, so requiring `<=` (instead of
    # rejecting only on `>`) turns a NaN difference into a mismatch rather
    # than a silent pass.
    return all(abs(v1 - v2) <= upper_bound for v1, v2 in zip(vec1, vec2))

### Torch 2d convolution

In [3]:
def torch_conv_2d(x, kernel, stride):
    """Plaintext reference convolution computed with PyTorch.

    Delegates to ``torch.nn.functional.conv2d`` with no padding, no dilation
    and no bias, so only the "valid" output region is produced.

    Args:
        x: input tensor handed to ``conv2d`` as ``input``.
        kernel: weight tensor handed to ``conv2d`` as ``weight``.
        stride: step between consecutive windows.

    Returns:
        The tensor produced by ``conv2d``.
    """
    functional = torch.nn.functional
    options = dict(stride=stride, padding=0, dilation=1)
    return functional.conv2d(x, kernel, **options)

### Memory strided Image Block to Columns implementation

In [4]:
def memory_strided_im2col(x, kernel_shape, stride=1):
    """Unroll the sliding windows of a 2-D input into rows (im2col).

    The windows are obtained with ``view_as_windows`` (a strided view of
    ``x``); the final ``reshape`` may copy when that view is not contiguous.

    Args:
        x: 2-D array of shape (x_h, x_w).
        kernel_shape: (k_h, k_w) window shape.
        stride: step between consecutive windows, applied on both axes.

    Returns:
        Array of shape (out_h * out_w, k_h * k_w). Windows are ordered
        row-major over the output grid and each window is flattened
        row-major.
    """
    x_h, x_w = x.shape
    k_h, k_w = kernel_shape
    # No padding; `stride` is honored on both axes. Floor division drops a
    # trailing partial window, matching view_as_windows with the same step.
    out_h = (x_h - k_h) // stride + 1
    out_w = (x_w - k_w) // stride + 1

    windows = view_as_windows(x, kernel_shape, step=stride)
    return windows.reshape(out_h * out_w, k_h * k_w)

### Generate input

In [5]:
# Square problem sizes: the input image is x_size x x_size and the kernel is
# k_size x k_size.
x_size = 4
k_size = 3
# Step between consecutive convolution windows.
stride = 1

# Deterministic inputs holding 1, 2, ..., n^2 in row-major order, which keeps
# the expected convolution output easy to verify by hand.
x = np.arange(1, x_size * x_size + 1).reshape((x_size, x_size))
kernel = np.arange(1, k_size * k_size + 1).reshape((k_size, k_size))

# Alternative: random standard-normal inputs (unseeded).
# x = np.random.randn(x_size, x_size)
# kernel = np.random.randn(k_size, k_size)

print("input", x.shape)
print(x)
print("kernel", kernel.shape)
print(kernel)

input (4, 4)
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]
kernel (3, 3)
[[1 2 3]
 [4 5 6]
 [7 8 9]]


### TenSEAL context

In [6]:
# TenSEAL CKKS context: polynomial modulus degree 8192 with a 60-40-40-60 bit
# coefficient modulus chain.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
# Encoding scale of 2^40.
context.global_scale = 2 ** 40
# Galois keys are needed to rotate ciphertext vectors.
context.generate_galois_keys()

For each convolution layer, a round of communication between the client and the server is required. The server sends the ciphertext (encrypted vector) that is the input of the next convolution layer to the client, which decrypts it and applies im2col (Image Block to Column) to that input.

In [7]:
# Client side: unroll every convolution window of x into one row.
new_x = memory_strided_im2col(x, kernel.shape, stride=stride)
print("new_x shape: ", new_x.shape)
print(new_x)

new_x shape:  (4, 9)
[[ 1  2  3  5  6  7  9 10 11]
 [ 2  3  4  6  7  8 10 11 12]
 [ 5  6  7  9 10 11 13 14 15]
 [ 6  7  8 10 11 12 14 15 16]]


After that, the client encodes and encrypts the resulting matrix in a vertical scan (column-major order) and sends it back to the server.

`new_x.flatten(order='F')` is equivalent to `new_x.T.flatten()`.


In [8]:
%%time 
# pad the input
print(kernel.size)
next_power2 = pow(2, math.ceil(math.log2(kernel.size)))
pad_width = next_power2 - kernel.size
padded_x = np.pad(new_x, ((0, 0), (0, pad_width)))
print(padded_x.shape)
print(padded_x)

x_enc = ts.ckks_vector(context, padded_x.flatten(order='F').tolist())
windows_nb = padded_x.shape[0]
print("flatten_kernel_size: ", kernel.size)
print("windows number: ", windows_nb)
print("ckksvector size: ", x_enc.size())

# pad the kernel
padded_kernel = np.pad(kernel.flatten(), (0, pad_width))

y_enc = x_enc.conv2d_im2col(padded_kernel.tolist(), windows_nb)

print(y_enc.size())
y_plain = y_enc.decrypt()

print("y_enc")
print(y_plain)

9
(4, 16)
[[ 1  2  3  5  6  7  9 10 11  0  0  0  0  0  0  0]
 [ 2  3  4  6  7  8 10 11 12  0  0  0  0  0  0  0]
 [ 5  6  7  9 10 11 13 14 15  0  0  0  0  0  0  0]
 [ 6  7  8 10 11 12 14 15 16  0  0  0  0  0  0  0]]
flatten_kernel_size:  9
windows number:  4
ckksvector size:  64
4
y_enc
[348.000048024785, 393.00005275438355, 528.0000706569963, 573.0000769331538]
CPU times: user 41.8 ms, sys: 0 ns, total: 41.8 ms
Wall time: 40.7 ms


### Compare the result to torch conv2d

In [9]:
# Plaintext reference: the same convolution computed with PyTorch.
# The two unsqueezes add leading batch and channel axes of size 1.
y_torch = torch_conv_2d(
    torch.from_numpy(x.astype("float32")).unsqueeze(0).unsqueeze(0),
    torch.from_numpy(kernel.astype("float32")).unsqueeze(0).unsqueeze(0),
    stride
)
y_torch = y_torch.flatten().numpy()
print("y_torch")
print(y_torch)

# CKKS is approximate, so compare with a tolerance. The tolerance scales with
# the result's magnitude (plus a small absolute floor), so the check stays
# meaningful for both the incremented inputs and the small random ones. A
# fixed absolute tolerance of 1.0 would accept almost anything for randn
# inputs. assert_allclose also fails on a length mismatch.
np.testing.assert_allclose(y_plain, y_torch, rtol=1e-5, atol=1e-3)

y_toch
[348. 393. 528. 573.]
