Precision comparison of four ways to compute the same 2D convolution:
1. ONNX model in cleartext ("onnxruntime")
2. torch.nn model in cleartext ("pytorch")
3. cleartext matrix-vector multiplication ("numpy")
4. encrypted matrix-vector multiplication ("tenseal")

**Result:**

Unfortunately, the results agree only to float (single) precision; apart from that limit, all computations match.

In [None]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import onnx
from onnx import helper
from onnx import numpy_helper
from onnx import TensorProto
import onnxruntime
import tenseal as ts
import torch
from torch import nn
from torch.onnx import export as export_onnx

# Spatial size (height == width) of the square input image, and the ONNX
# opset used for every model built in this notebook.
input_length, OPSET_VERSION = 28, 14

In [None]:
# 2D Conv
# Build a single Conv2d layer (2 -> 4 channels, 2x2 kernel, stride 2) and
# export it to ONNX so every backend below evaluates the same weights.
n_inputs = input_length * input_length
n_channels_in, n_channels_out = 2, 4
kernel_size = stride = 2
model = nn.Conv2d(
    in_channels=n_channels_in,
    out_channels=n_channels_out,
    kernel_size=kernel_size,
    stride=stride,
)

# The dummy input presumably only fixes the traced shape, so its (uninitialised)
# values should not matter for the export.
export_onnx(
    model,
    torch.empty([1, n_channels_in, input_length, input_length]),
    "./conv.onnx",
    opset_version=OPSET_VERSION,
)

In [None]:
# Read the exported graph back in; its initializers hold the layer's parameters.
model_path = "./conv.onnx"
model = onnx.load_model(model_path)

In [None]:
# input for all tests
np.random.seed(27)  # fixed seed so every run compares the same input
x = np.asarray(
    np.random.random([1, n_channels_in, input_length, input_length]), np.float32
)

# Fetch the Conv parameters by the names the Conv node references (its inputs
# are X, W, B) instead of relying on the position of each tensor inside
# graph.initializer, whose ordering is not something to depend on.
initializers = {
    init.name: numpy_helper.to_array(init) for init in model.graph.initializer
}
conv_inputs = model.graph.node[0].input  # assumes the Conv is the first node
w = initializers[conv_inputs[1]]
b = initializers[conv_inputs[2]]

result_dict = {}

In [None]:
# 1 - direct onnxruntime
# Pin the CPU execution provider: recent onnxruntime builds require an explicit
# provider list when several are available, and a fixed provider keeps the
# precision comparison reproducible.
session = onnxruntime.InferenceSession(
    model_path, providers=["CPUExecutionProvider"]
)
y = session.run(None, {model.graph.input[0].name: x})

# run() returns one array per graph output; this graph has a single output.
result_dict["onnxruntime"] = y[0].ravel()

In [None]:
# 2 - nn.Conv2d
# Rebuild the exported layer in PyTorch. The hyperparameters must match the
# exported model (in/out channels, kernel size, stride); otherwise the output
# shape differs from the ONNX result and the comparison below is meaningless.
# The layer computes in double precision, so it acts as a higher-precision
# reference for the float32 weights and input.
conv_layer = nn.Conv2d(
    n_channels_in, n_channels_out, kernel_size, stride=stride, dtype=torch.double
)
with torch.no_grad():
    # copy_ casts float32 -> float64 and fails loudly on a shape mismatch,
    # unlike reassigning .data, which silently replaced the parameters.
    conv_layer.weight.copy_(torch.from_numpy(w))
    conv_layer.bias.copy_(torch.from_numpy(b))
    x2 = torch.from_numpy(x).double()
    y2 = conv_layer(x2)

result_dict["pytorch"] = y2.numpy().ravel()

In [None]:
## compute M_conv and bias_conv
# Express the convolution as an affine map on the flattened input:
#     conv(x).ravel() == x.reshape(1, -1) @ conv_matrix + bias_conv
# conv_matrix is obtained by running the bias-free Conv on every standard basis
# vector of the flattened input space.

n_channels_in = 2
n_channels_out = 4
in_shape = [input_length, input_length]
n_dims = len(in_shape)

conv_node = model.graph.node[0]  # assumes the Conv is the first node of the graph

# Start from the ONNX Conv defaults so an attribute the exporter omits does not
# raise a KeyError below; the exported attribute values override them.
atts = {
    "dilations": [1] * n_dims,
    "group": 1,
    "kernel_shape": list(w.shape[2:]),
    "pads": [0] * (2 * n_dims),
    "strides": [1] * n_dims,
}
for a in conv_node.attribute:
    atts[a.name] = helper.get_attribute_value(a)

# Output size per spatial axis. ONNX pads are laid out as
# [x1_begin, x2_begin, ..., x1_end, x2_end], so pads[i::n_dims] selects the
# begin/end padding of axis i.
out_shape = [
    (
        in_shape[i]
        + sum(atts["pads"][i::n_dims])
        - atts["dilations"][i] * (atts["kernel_shape"][i] - 1)
        - 1
    )
    // atts["strides"][i]
    + 1
    for i in range(n_dims)
]

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, n_channels_in] + in_shape)
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, n_channels_out] + out_shape)

# Bias-free Conv node: the bias is not part of the linear map and is added
# separately as bias_conv below.
node_def = helper.make_node(
    "Conv",  # op type
    ["X", "W"],  # inputs
    ["Y"],  # outputs
    dilations=atts["dilations"],
    group=atts["group"],
    kernel_shape=atts["kernel_shape"],
    pads=atts["pads"],
    strides=atts["strides"],
)

weight_init = numpy_helper.from_array(w, name="W")

# Create the graph (GraphProto)
graph_def = helper.make_graph([node_def], "conv-model", [X], [Y], initializer=[weight_init])

model_def = helper.make_model(graph_def)
model_def.opset_import[0].version = OPSET_VERSION
onnx.checker.check_model(model_def)
# InferenceSession accepts the serialized model bytes directly.
session = onnxruntime.InferenceSession(
    model_def.SerializeToString(), providers=["CPUExecutionProvider"]
)

full_in_shape = [1, n_channels_in, input_length, input_length]
n_in = int(np.prod(full_in_shape[1:]))

# Batch of n_in samples: sample j is the j-th basis vector reshaped to
# (C, H, W). Its convolution becomes row j of conv_matrix.
eye = np.eye(n_in, dtype=w.dtype).reshape([n_in, n_channels_in] + in_shape)

conv_matrix = session.run(None, {"X": eye})[0].reshape(
    n_in,
    n_channels_out * int(np.prod(out_shape)),
)

# The output is flattened channel-major (C_out, H_out, W_out), so each
# channel's bias is repeated once per output pixel.
bias_conv = np.repeat(b, np.prod(out_shape))

In [None]:
## 3 - clear matrix vector multiplication
# The convolution as one affine map applied to the flattened (1, C*H*W) input row.
x3 = np.reshape(x, (1, -1))
y3 = np.matmul(x3, conv_matrix) + bias_conv
result_dict["numpy"] = np.ravel(y3)

In [None]:
## 4 - encrypted matrix vector multiplication
bits_scale = 44

# CKKS context whose middle modulus prime has the same bit size as the scale
# (presumably the level consumed by the single plaintext-matrix multiplication).
# Galois keys are generated because TenSEAL's vector-matrix product is assumed
# to need rotations.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, bits_scale, 60],
)
context.global_scale = 2 ** bits_scale
context.generate_galois_keys()

# Encrypt the flattened input, multiply by the plain convolution matrix and
# add the plain bias, then decrypt for comparison.
x4 = x.ravel()
x_enc = ts.ckks_vector(context, x4)
y4 = x_enc @ conv_matrix
y4 = y4 + bias_conv

result_dict["tenseal"] = np.array(y4.decrypt())

In [None]:
# Compare every pair of results once (in insertion order): print the largest
# absolute deviation and plot the distribution of element-wise differences.
for (k, v), (k2, v2) in itertools.combinations(result_dict.items(), 2):
    # Guard against silently broadcasting or failing on mismatched outputs.
    if v.shape != v2.shape:
        raise ValueError(f"shape mismatch: {k} {v.shape} vs {k2} {v2.shape}")
    delta = v - v2
    max_err = np.max(np.abs(delta))

    print(f"({k}) vs ({k2}) max error: {max_err}", flush=True)
    plt.hist(delta, bins="auto")
    plt.title(f"({k}) vs ({k2})")
    plt.show()