In [1]:
import io
import numpy as np
import onnx
import onnxruntime
import tenseal as ts
import torch

from onnx import helper
from onnx import numpy_helper
from onnx import TensorProto
from torch import nn
from torch.onnx import export as export_onnx

In [2]:
# Side length (in pixels) of the square input images fed to the network
input_length = 32
# ONNX opset version stamped on the single-operator models built by the helper functions
OPSET_VERSION = 14

In [3]:
# ONNX graph whose nodes are converted one by one below
# (presumably a LeNet-5 variant with square activations, judging by the file name — confirm)
model = onnx.load("../tenseal-inference/models/lenet-5_square.onnx")

In [5]:
### def: conv
def convlayer(x, w, b, in_shape, node_nr):
    """Express the Conv node ``model.graph.node[node_nr]`` as a matrix and a bias vector.

    A one-node ONNX model containing only the bias-free convolution is built and
    run with onnxruntime on an identity matrix. Every row of that matrix is a
    one-hot flattened input, so the stacked outputs form the matrix ``M`` with
    ``flat_input @ M + bias`` equal to the flattened output of the layer.

    Args:
        x: Unused; kept so that existing callers keep working.
        w: Convolution weights as a NumPy array; ``w.shape[0]`` is the number
            of output channels.
        b: Bias with one entry per output channel.
        in_shape: Input shape as a list ``[batch, channels, *spatial]``. Only
            ``in_shape[1:]`` is used.
        node_nr: Index of the Conv node in ``model.graph.node`` whose
            attributes are copied.

    Returns:
        ``(conv_matrix, bias_conv)`` where ``conv_matrix`` has shape
        ``(channels_in * prod(spatial_in), channels_out * prod(spatial_out))``
        and ``bias_conv`` repeats every bias entry ``prod(spatial_out)`` times.

    Note:
        ``auto_pad`` is not interpreted; padding is taken from ``pads`` only.
    """
    atts = {}
    for a in model.graph.node[node_nr].attribute:
        # "group" is a scalar int; the other attributes read below are int lists
        atts[a.name] = a.i if a.name == "group" else list(a.ints)

    n_channels_in = in_shape[1]
    spatial_in = list(in_shape[2:])
    n_dims = len(spatial_in)
    n_channels_out = w.shape[0]

    # Fall back to the defaults of the ONNX Conv operator for omitted attributes
    atts.setdefault("group", 1)
    atts.setdefault("kernel_shape", list(w.shape[2:]))
    atts.setdefault("dilations", [1] * n_dims)
    atts.setdefault("strides", [1] * n_dims)
    atts.setdefault("pads", [0] * (2 * n_dims))

    # pads is laid out as [begin_0, ..., begin_n, end_0, ..., end_n], hence the
    # stride-n_dims slice to collect both pads of axis i
    out_shape = [
        (
            spatial_in[i]
            + sum(atts["pads"][i::n_dims])
            - atts["dilations"][i] * (atts["kernel_shape"][i] - 1)
            - 1
        )
        // atts["strides"][i]
        + 1
        for i in range(n_dims)
    ]

    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_channels_in] + spatial_in)
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_channels_out] + out_shape)

    # Create Conv node
    node_def = helper.make_node(
        "Conv",  # operator type
        ["X", "W"],  # inputs (B is optional and dropped; the bias is returned separately)
        ["Y"],  # outputs
        dilations=atts["dilations"],
        group=atts["group"],
        kernel_shape=atts["kernel_shape"],
        pads=atts["pads"],
        strides=atts["strides"],
    )

    # numpy_helper picks the tensor type from the array dtype, which avoids the
    # deprecated onnx.mapping lookup table
    weight_init = numpy_helper.from_array(w, name="W")

    # Create the graph (GraphProto)
    graph_def = helper.make_graph([node_def], "conv-model", [X], [Y], initializer=[weight_init])

    model_def = helper.make_model(graph_def)
    model_def.opset_import[0].version = OPSET_VERSION
    onnx.checker.check_model(model_def)
    session = onnxruntime.InferenceSession(model_def.SerializeToString())

    # Size of the flattened input, derived from in_shape itself so that layers
    # whose input is not input_length x input_length are handled as well
    n_inputs = n_channels_in * int(np.prod(spatial_in))
    n_outputs = n_channels_out * int(np.prod(out_shape))

    # run convolution with identity matrix as input to get convolution matrix
    eye = np.eye(n_inputs, dtype=w.dtype).reshape([n_inputs, n_channels_in] + spatial_in)
    conv_matrix = np.asarray(session.run(None, {"X": eye})[0]).reshape(n_inputs, n_outputs)

    print(f"conv matrix: {conv_matrix.shape}")

    # one bias value per output element: each channel bias repeated over its feature map
    bias_conv = np.repeat(b, int(np.prod(out_shape)))

    # result = x @ conv_matrix + bias_conv

    return conv_matrix, bias_conv

In [22]:
### def: averagepool
def avgpoollayer(in_shape, node_nr, kernel_shape=None, strides=(2, 2)):
    """Express the AveragePool node ``model.graph.node[node_nr]`` as a matrix.

    A one-node ONNX model is built and run with onnxruntime on an identity
    matrix, so that ``flat_input @ M_avgpool`` equals the flattened pooling
    output.

    Args:
        in_shape: Input shape as a list ``[batch, channels, *spatial]``. Only
            ``in_shape[1:]`` is used.
        node_nr: Index of the AveragePool node in ``model.graph.node`` whose
            attributes are copied.
        kernel_shape: Kernel size per spatial axis. When ``None`` the
            ``kernel_shape`` attribute of the node is used.
        strides: Stride per spatial axis. This replaces the strides stored on
            the node; the default ``(2, 2)`` is the value that used to be
            hard-coded here.

    Returns:
        Array of shape ``(channels * prod(spatial_in), channels * prod(spatial_out))``.

    Raises:
        ValueError: If no kernel shape is given and the node has none.
    """
    atts = {}
    for a in model.graph.node[node_nr].attribute:
        # ceil_mode / count_include_pad are scalar ints, the rest are int lists
        if a.name in ("ceil_mode", "count_include_pad"):
            atts[a.name] = a.i
        else:
            atts[a.name] = list(a.ints)

    dims = list(in_shape[2:])
    n_dims = len(dims)
    n_channels = in_shape[1]

    # Overrides applied regardless of what the node stores
    atts["count_include_pad"] = 0
    atts["strides"] = list(strides)
    if kernel_shape is not None:
        atts["kernel_shape"] = list(kernel_shape)
    if "kernel_shape" not in atts:
        raise ValueError("kernel_shape was not given and the node does not define one")

    # Defaults of the ONNX AveragePool operator for omitted attributes
    atts.setdefault("ceil_mode", 0)
    atts.setdefault("pads", [0] * (2 * n_dims))

    out_shape = []
    for i in range(n_dims):
        # pads is [begin_0, ..., begin_n, end_0, ..., end_n]
        span = dims[i] + sum(atts["pads"][i::n_dims]) - atts["kernel_shape"][i]
        stride = atts["strides"][i]
        if atts["ceil_mode"]:
            # integer ceiling division, no math import required
            out_shape.append(-(-span // stride) + 1)
        else:
            out_shape.append(span // stride + 1)

    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_channels] + dims)
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_channels] + out_shape)

    # Create AveragePool node
    node_def = helper.make_node(
        "AveragePool",
        inputs=["X"],
        outputs=["Y"],
        ceil_mode=atts["ceil_mode"],
        count_include_pad=atts["count_include_pad"],
        kernel_shape=atts["kernel_shape"],
        pads=atts["pads"],
        strides=atts["strides"],
    )

    # Create the graph (GraphProto)
    graph_def = helper.make_graph([node_def], "avgpool-model", [X], [Y])

    model_def = helper.make_model(graph_def)
    model_def.opset_import[0].version = OPSET_VERSION
    onnx.checker.check_model(model_def)
    session = onnxruntime.InferenceSession(model_def.SerializeToString())

    n_inputs = n_channels * int(np.prod(dims))
    n_outputs = n_channels * int(np.prod(out_shape))

    # Each row of the identity matrix is a one-hot flattened input; pooling all
    # rows at once yields the rows of the pooling matrix
    eye = np.eye(n_inputs, dtype=np.float32).reshape([n_inputs, n_channels] + dims)
    M_avgpool = np.asarray(session.run(None, {"X": eye})[0]).reshape(n_inputs, n_outputs)

    return M_avgpool

In [49]:
### Create the encryption context
# Bit size used both for the inner coefficient-modulus primes and for the encoding scale
bits_scale = 32
    
# Create TenSEAL context
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=16384,
    # two 43-bit outer primes around 11 inner primes of bits_scale bits each
    coeff_mod_bit_sizes=[43] + [bits_scale]*11 + [43]
)
# Scale applied when encoding values, chosen to match the inner prime size
context.global_scale = 2**bits_scale
# Galois keys allow ciphertext rotations — presumably needed by the encrypted
# vector-matrix products further down; confirm against the TenSEAL docs
context.generate_galois_keys()

In [3]:
def ascii_plot(img, char=" ", outer="X"):
    """Print a 2-D image as ASCII art, one text line per image row.

    Pixels strictly greater than 0.5 are drawn with ``char``; every other
    pixel is drawn with ``outer``.
    """
    for scanline in img:
        glyphs = [char if value > .5 else outer for value in scanline]
        # sep="" glues the glyphs together; print adds the line break
        print(*glyphs, sep="")
        
        
# def mnist_plt(img):
#     img = np.matrix(img)
#     plt.imshow(img, cmap="gray")
#     plt.axis('off')
#     plt.show()

In [12]:
### Load the plaintext sample and encrypt it
image = np.load("../../../../tenseal-inference/tmp/mnist_32x32_data/51.npy")

# image[0][0] is the 2-D pixel grid of the sample
ascii_plot(image[0][0])

# The layer matrices act on a flat vector, so flatten the sample before
# encrypting it. Both names are needed by the cells below (`pt @ m`, `ct @ m`).
pt = image.flatten()
ct = ts.ckks_vector(context, pt)

XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXX          XXXXXXXXXXXXXX
XXXXX               XXXXXXXXXXXX
XXXXX                XXXXXXXXXXX
XXXXX       XXXX     XXXXXXXXXXX
XXXXXXXXXXXXXXXXX    XXXXXXXXXXX
XXXXXXXXXXXXXXXX     XXXXXXXXXXX
XXXXXXXXXXXXXXX      XXXXXXXXXXX
XXXXXXXXXXXX          XXXXXXXXXX
XXXXXXXXX              XXXXXXXXX
XXXXXXXXX                 XXXXXX
XXXXXXXXX      XXXXX       XXXXX
XXXXXXXXXXXXXXXXXXXXXX     XXXXX
XXXXXXXXXXXXXXXXXXXXXXX    XXXXX
XXXXXXXXXXXXXXXXXXXXXXXX   XXXXX
XXXXXXXXXXXXXXXXXXXXXXXX   XXXXX
XXXXXXXXXX  XXXXXXXXXX     XXXXX
XXXXXXXXXX       X        XXXXXX
XXXXXXXXXX               XXXXXXX
XXXXXXXXXXXX            XXXXXXXX
XXXXXXXXXXXXXXXXX   XXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
XXXXXXXXXX

In [9]:
### Go through onnx
# Look-up table: initializer name -> stored tensor converted to a NumPy array
initializers = dict(
    (init.name, numpy_helper.to_array(init)) for init in model.graph.initializer
)

In [10]:
### 0: Conv

## Calculate matrix and bias

node = model.graph.node[0]
print(node.name)

# The second and third inputs of the node name its weight and bias initializers
w = initializers[node.input[1]]
b = initializers[node.input[2]]

# Shape declared for the first graph input (symbolic dimensions read as 0 via dim_value)
shape = [d.dim_value for d in model.graph.input[0].type.tensor_type.shape.dim]

# Note: `b` is rebound here from the bias read above to the expanded bias
# vector returned by convlayer
m,b = convlayer(pt, w, b, shape, 0)
print(m.shape, b.shape)

Conv_0
conv matrix: (1024, 4704)
(1024, 4704) (4704,)


In [11]:
# Conv: apply the layer as a vector-matrix product plus bias
# assumes `pt` is the flattened input vector and `ct` its CKKS encryption — TODO confirm
# cleartext matrix mult
y = pt @ m + b
# encrypted matrix mult
y_ct = ct @ m + b

In [12]:
# Decrypted result of the encrypted convolution, to compare against the cleartext result
np.array(y_ct.decrypt())

array([ 0.12274845,  0.12274784,  0.12274706, ..., -0.06017368,
       -0.06017292, -0.06017332])

In [13]:
# Cleartext result of the convolution (reference for the decrypted values)
np.array(y)

array([ 0.12274782,  0.12274782,  0.12274782, ..., -0.06017318,
       -0.06017318, -0.06017318], dtype=float32)

In [14]:
# Largest element-wise deviation between the decrypted and the cleartext result
np.abs(np.asarray(y_ct.decrypt()) - np.asarray(y)).max()

0.00010509137386516265

In [15]:
### 1: Mul
# Element-wise square of the convolution output, on cleartext and on ciphertext
y_2 = y**2
y_ct_2 = y_ct ** 2

In [16]:
# Decrypted result of the encrypted squaring
np.array(y_ct_2.decrypt())

array([0.01507209, 0.01507068, 0.01507037, ..., 0.00362204, 0.00362185,
       0.00362177])

In [17]:
# Cleartext result of the squaring (reference for the decrypted values)
np.array(y_2)

array([0.01506703, 0.01506703, 0.01506703, ..., 0.00362081, 0.00362081,
       0.00362081], dtype=float32)

In [18]:
# Largest element-wise deviation between the decrypted and the cleartext result
np.abs(np.asarray(y_ct_2.decrypt()) - np.asarray(y_2)).max()

0.0003137399548880371

In [19]:
### 2: Const

In [20]:
### 3: Pad

In [23]:
### 4: AveragePool
# NOTE(review): nodes 2 (Const) and 3 (Pad) are not applied before this step —
# confirm that the padding is accounted for in the pooling matrix

## calculate matrix

node = model.graph.node[4]
print(node.name)
# 6 channels of 28x28 values, i.e. the 4704 outputs of the convolution matrix
shape = [1, 6, 28, 28]

# Only the input shape and the node index are passed here
m_avgpool = avgpoollayer(shape, 4)
print(m_avgpool.shape)

AveragePool_4
(4704, 726)


In [24]:
# AveragePool
# perform matrix multiplication (pooling has no bias term), on cleartext and on ciphertext
y_3 = y_2 @ m_avgpool
y_ct_3 = y_ct_2 @ m_avgpool

In [25]:
# First 10 entries of the decrypted result of the encrypted pooling
np.array(y_ct_3.decrypt())[:10]

array([0.01403442, 0.01636801, 0.01911621, 0.02267523, 0.0259698 ,
       0.02528922, 0.02142157, 0.01634411, 0.01165139, 0.01288812,
       0.01441646, 0.05319747, 0.08660934, 0.11894748, 0.13593958,
       0.13755851, 0.11698975, 0.08629081, 0.05362647, 0.03161544,
       0.02072492, 0.01655304, 0.08065655, 0.13082684, 0.16738136,
       0.17517408, 0.1616479 , 0.13498461, 0.11807406, 0.0884773 ,
       0.06423259, 0.03804512, 0.01791253, 0.08209787, 0.13134078,
       0.16695143, 0.1744439 , 0.16574848, 0.15339788, 0.15209949,
       0.12248668, 0.0893013 , 0.04279055, 0.01587027, 0.06862564,
       0.10457203, 0.13231481, 0.14626585, 0.16065897, 0.18805962,
       0.21871524, 0.19794046, 0.14792396, 0.06943561, 0.02032446,
       0.02411983, 0.04921017, 0.09271947, 0.14221408, 0.18552384,
       0.21975038, 0.24051165, 0.22137304, 0.17949023, 0.10601275,
       0.05245037, 0.01661472, 0.04411599, 0.09233953, 0.14790835,
       0.19166757, 0.21325054, 0.21289053, 0.19488934, 0.16338

In [26]:
# First 10 entries of the cleartext pooling result (reference for the decrypted values)
np.array(y_3)[:10]

array([0.01402482, 0.01635544, 0.01909945, 0.02265636, 0.02594814,
       0.02526968, 0.02140413, 0.01633215, 0.011643  , 0.01288001,
       0.01440761, 0.05315443, 0.08653777, 0.11884884, 0.13582784,
       0.13744381, 0.11689145, 0.08621946, 0.05358465, 0.03159159,
       0.02071063, 0.01654229, 0.08059102, 0.13071921, 0.16724077,
       0.17502743, 0.1615144 , 0.13487305, 0.11797799, 0.08840645,
       0.06418159, 0.03801695, 0.01790052, 0.08203217, 0.13123341,
       0.16681375, 0.17429808, 0.16561027, 0.15327217, 0.15197545,
       0.12238874, 0.08922977, 0.04275911, 0.01585981, 0.06857053,
       0.10448653, 0.1322063 , 0.14614493, 0.1605257 , 0.18790399,
       0.21853498, 0.19777757, 0.14780328, 0.06938038, 0.02030989,
       0.02410257, 0.04917143, 0.09264371, 0.14209697, 0.18536943,
       0.21956788, 0.24031147, 0.22119118, 0.17934376, 0.10592797,
       0.05240904, 0.01660387, 0.04408179, 0.09226432, 0.14778702,
       0.19150919, 0.21307456, 0.21271527, 0.19472972, 0.16325

In [27]:
# Largest element-wise deviation between the decrypted and the cleartext result
np.abs(np.asarray(y_ct_3.decrypt()) - np.asarray(y_3)).max()

0.00020886406581244188

In [7]:
def avgpool_matrix(in_shape, kernel_size, stride=2):
    """Build the matrix of an unpadded average pooling with a cubic kernel.

    A one-node ONNX AveragePool model is run with onnxruntime on an identity
    matrix, so that ``flat_input @ M_avgpool`` equals the flattened pooling
    output. Unlike ``avgpoollayer`` nothing is read from the loaded model.

    Args:
        in_shape: Input shape as a list ``[batch, channels, *spatial]``. Only
            ``in_shape[1:]`` is used.
        kernel_size: Kernel extent, applied to every spatial axis.
        stride: Stride, applied to every spatial axis. Defaults to 2, the
            value that used to be hard-coded here.

    Returns:
        Array of shape ``(channels * prod(spatial_in), channels * prod(spatial_out))``.

    Raises:
        ValueError: If ``stride`` is smaller than 1 or the kernel does not fit
            into one of the spatial dimensions.
    """
    dims = list(in_shape[2:])
    n_dims = len(dims)
    n_channels = in_shape[1]

    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if any(d < kernel_size for d in dims):
        raise ValueError(f"kernel_size {kernel_size} does not fit into spatial shape {dims}")

    # floor mode, no padding
    out_shape = [(d - kernel_size) // stride + 1 for d in dims]

    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_channels] + dims)
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_channels] + out_shape)

    # Create AveragePool node; the per-axis attributes follow the number of
    # spatial dimensions instead of assuming 2-D input
    node_def = helper.make_node(
        "AveragePool",
        inputs=["X"],
        outputs=["Y"],
        ceil_mode=0,
        count_include_pad=0,
        kernel_shape=[kernel_size] * n_dims,
        pads=[0] * (2 * n_dims),
        strides=[stride] * n_dims,
    )

    # Create the graph (GraphProto)
    graph_def = helper.make_graph([node_def], "avgpool-model", [X], [Y])

    model_def = helper.make_model(graph_def)
    model_def.opset_import[0].version = OPSET_VERSION
    onnx.checker.check_model(model_def)
    session = onnxruntime.InferenceSession(model_def.SerializeToString())

    n_inputs = n_channels * int(np.prod(dims))
    n_outputs = n_channels * int(np.prod(out_shape))

    # Each row of the identity matrix is a one-hot flattened input; pooling all
    # rows at once yields the rows of the pooling matrix
    eye = np.eye(n_inputs, dtype=np.float32).reshape([n_inputs, n_channels] + dims)
    M_avgpool = np.asarray(session.run(None, {"X": eye})[0]).reshape(n_inputs, n_outputs)

    return M_avgpool