In [1]:
from time import perf_counter, time

import numpy as np
import tenseal as ts
import torch
from skimage.util.shape import view_as_windows

In [2]:
def almost_equal(vec1, vec2, m_pow_ten):
    """Check whether two sequences agree element-wise within ``10 ** -m_pow_ten``.

    Args:
        vec1: first sized sequence of numbers (list, 1-D array, ...).
        vec2: second sized sequence of numbers.
        m_pow_ten: tolerance exponent; a pair matches when
            ``abs(v1 - v2) <= 10 ** -m_pow_ten`` (so ``0`` means a tolerance of 1,
            ``3`` a tolerance of 0.001).

    Returns:
        True if both sequences have the same length and every pair of
        elements is within the tolerance, False otherwise. A NaN in either
        input always yields False.
    """
    if len(vec1) != len(vec2):
        return False

    upper_bound = pow(10, -m_pow_ten)
    # Test "within bound" with ``<=`` rather than "outside bound" with ``>``:
    # every comparison against NaN is False, so the ``>`` form would let a NaN
    # difference pass as a match. ``all`` still stops at the first mismatch.
    return all(abs(v1 - v2) <= upper_bound for v1, v2 in zip(vec1, vec2))

In [3]:
def memory_strided_im2col(x, kernel, stride=1):
    """
    Memory strided Image Block to Columns implementation.

    Every ``k_h x k_w`` patch of ``x`` visited by the kernel (no padding)
    becomes one row of the result, flattened in row-major order. Rows are
    ordered row-major over the patches' top-left corners.

    Args:
        x: 2-D input image of shape ``(x_h, x_w)``.
        kernel: 2-D kernel; only its shape ``(k_h, k_w)`` is used.
        stride: step between consecutive patches along both axes (default 1).

    Returns:
        ndarray of shape ``(out_h * out_w, k_h * k_w)`` where
        ``out_h = (x_h - k_h) // stride + 1`` and ``out_w = (x_w - k_w) // stride + 1``.

    Raises:
        ValueError: if ``x`` or ``kernel`` is not 2-D, the kernel is larger
            than the image, or ``stride`` is not a positive integer.
    """
    x = np.asarray(x)
    kernel_shape = np.shape(kernel)
    if x.ndim != 2 or len(kernel_shape) != 2:
        raise ValueError(
            f"expected a 2-D image and a 2-D kernel, got shapes {x.shape} and {kernel_shape}"
        )
    if not isinstance(stride, (int, np.integer)) or stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    x_h, x_w = x.shape
    k_h, k_w = kernel_shape
    if k_h > x_h or k_w > x_w:
        raise ValueError(
            f"kernel of shape {kernel_shape} does not fit in input of shape {x.shape}"
        )

    # View of shape (x_h - k_h + 1, x_w - k_w + 1, k_h, k_w): entry [i, j] is
    # the patch whose top-left corner is (i, j). No data is copied here.
    windows = np.lib.stride_tricks.sliding_window_view(x, kernel_shape)
    # Keep only the patches the kernel actually visits for this stride.
    windows = windows[::stride, ::stride]
    out_h, out_w = windows.shape[:2]
    # One row per patch; the windows overlap in memory, so this reshape
    # generally materialises a copy.
    return windows.reshape(out_h * out_w, k_h * k_w)

In [4]:
def torch_conv_2d(x, kernel):
    """Plaintext reference 2-D convolution computed with PyTorch.

    Runs ``torch.nn.functional.conv2d`` with stride 1, no padding and no
    dilation. Like ``conv2d`` itself, this is a cross-correlation: the kernel
    is not flipped.

    Args:
        x: input tensor laid out as ``(batch, in_channels, H, W)``.
        kernel: weight tensor laid out as ``(out_channels, in_channels, kH, kW)``.

    Returns:
        Tensor of shape ``(batch, out_channels, H - kH + 1, W - kW + 1)``.
    """
    conv_settings = {"stride": 1, "padding": 0, "dilation": 1}
    conv2d = torch.nn.functional.conv2d
    return conv2d(x, kernel, **conv_settings)

In [11]:
def ckks_conv2d(x, kernel, context, verbose=True):
    """Encrypted 2-D convolution (stride 1, no padding) via im2col and CKKS.

    The plaintext image is unrolled with im2col, encrypted column by column
    into a single CKKS vector, and multiplied by the plaintext, row-major
    flattened kernel with ``mat_plain_vec_mult_inplace``.

    Protocol note: each convolution layer needs one client/server round trip.
    The server sends the ciphertext (encrypted vector) that is the input of
    the next convolution layer to the client. The client decrypts it and
    applies im2col (image block to columns) to it, then encodes and encrypts
    the resulting matrix in a vertical (column-major) scan and sends it back
    to the server.

    Args:
        x: 2-D plaintext input image (numpy array).
        kernel: 2-D plaintext kernel (numpy array).
        context: TenSEAL CKKS context used to encrypt the input. The
            multiplication presumably needs its galois keys for rotations
            (the notebook generates them before calling) -- TODO confirm.
        verbose: if True (default), print the im2col matrix, its shape, the
            sizes involved and the multiplication time.

    Returns:
        The encrypted vector after the in-place matrix/vector product. In the
        recorded run (4x4 image, 2x2 kernel) it holds one value per im2col row
        (size 9).
    """
    new_x = memory_strided_im2col(x, kernel)
    if verbose:
        print("new_x shape: ", new_x.shape)
        print(new_x)

    # new_x.flatten(order='F') is equivalent to new_x.T.flatten(): the matrix
    # is serialised column by column (vertical scan) before encryption.
    x_enc = ts.ckks_vector(context, new_x.flatten(order='F').tolist())

    rows_number = new_x.shape[0]
    # Flatten the kernel once; it is needed both for reporting and the product.
    flat_kernel = kernel.flatten().tolist()
    kernel_size = len(flat_kernel)
    if verbose:
        print("flatten_kernel_size: ", kernel_size)
        print("rows_number: ", rows_number)
        print("ckksvector size: ", x_enc.size())

    # perf_counter is monotonic and high-resolution, unlike time(), which makes
    # it the appropriate clock for measuring an elapsed duration.
    start = perf_counter()
    x_enc.mat_plain_vec_mult_inplace(flat_kernel, rows_number)
    elapsed = perf_counter() - start
    if verbose:
        print("time cost:", elapsed)
    return x_enc

In [12]:
# Input image dimension: x_size x x_size
x_size = 4
# Kernel dimension: k_size x k_size
k_size = 2
# Number of decimal places the decrypted CKKS result must match PyTorch to.
# CKKS is approximate; the recorded run shows errors around 2e-5, so 1e-3
# leaves ample margin while still catching genuinely wrong results.
match_decimals = 3
# Set to True to use seeded random floats instead of 1, 2, ..., n^2.
use_random_input = False

if use_random_input:
    rng = np.random.default_rng(0)
    x = rng.random((x_size, x_size))
    kernel = rng.random((k_size, k_size))
else:
    # Incremented values: 1, 2, ..., n^2
    x = np.arange(1, x_size ** 2 + 1).reshape(x_size, x_size)
    kernel = np.arange(1, k_size ** 2 + 1).reshape(k_size, k_size)

print("input", x.shape)
print(x)
print("kernel", kernel.shape)
print(kernel)


# Create the TenSEAL CKKS context (polynomial modulus degree 8192,
# coefficient modulus bit sizes [60, 40, 40, 60]).
context = ts.context(
    ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60]
)
# Set the encoding scale.
context.global_scale = pow(2, 40)
# Generate galois keys in order to do rotations on ciphertext vectors.
context.generate_galois_keys()


y_enc = ckks_conv2d(x, kernel, context)
print(y_enc.size())
y_plain = y_enc.decrypt()

print("y_enc")
print(y_plain)


# Plaintext reference: conv2d expects (batch, channels, H, W), hence the unsqueezes.
y_torch_tensor = torch_conv_2d(
    torch.from_numpy(x.astype("float32")).unsqueeze(0).unsqueeze(0),
    torch.from_numpy(kernel.astype("float32")).unsqueeze(0).unsqueeze(0),
)
y_torch = y_torch_tensor.flatten().numpy()
print("y_torch")
print(y_torch)

assert almost_equal(y_plain, y_torch, match_decimals), (
    f"CKKS result differs from PyTorch by more than 1e-{match_decimals}"
)

input (4, 4)
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]
kernel (2, 2)
[[1 2]
 [3 4]]
new_x shape:  (9, 4)
[[ 1  2  5  6]
 [ 2  3  6  7]
 [ 3  4  7  8]
 [ 5  6  9 10]
 [ 6  7 10 11]
 [ 7  8 11 12]
 [ 9 10 13 14]
 [10 11 14 15]
 [11 12 15 16]]
flatten_kernel_size:  4
rows_number:  9
ckksvector size:  36
time cost: 0.014109134674072266
9
y_enc
[44.000006357476, 54.00000741829126, 64.00000852519227, 84.00001123350282, 94.00001262979303, 104.00001397124595, 124.00001661905705, 134.00001796757925, 144.00001931391517]
y_toch
[ 44.  54.  64.  84.  94. 104. 124. 134. 144.]
