In [3]:
from FourierCKKS import *
import numpy as np
from fft.fft import linear_convolution_direct, linear_convolution2d_direct

# 1) Ensure constructor rejects non-square data_len
# poly_degree=1024 => data_len=512, which is not a perfect square.
try:
    FourierCKKS(poly_degree=1024)
except ValueError as e:
    print("Correctly raised exception for non-square data_len:", e)
else:
    # Without this branch a missing validation would make the check pass silently.
    raise AssertionError(
        "FourierCKKS(poly_degree=1024) should have raised ValueError for non-square data_len"
    )

# Use poly_degree=512 => data_len=256 => img_side=16
fcks = FourierCKKS(poly_degree=512)

# ------- 1D Tests -------
# Cell-local fixed seed so Restart & Run All reproduces the same inputs (and verdicts).
rng_1d = np.random.default_rng(0)
len1, len2 = 300, 200
m1 = rng_1d.random(len1) + 1j * rng_1d.random(len1)
m2 = rng_1d.random(len2) + 1j * rng_1d.random(len2)
conv_len = len1 + len2 - 1  # length of the full linear convolution of m1 and m2
atol = 1e-1  # shared tolerance for every check in this cell

# Both inputs are packed at the full linear-convolution length; the convolution
# result below is compared against a direct linear convolution.
ct1 = fcks.forward(m1, target_length=conv_len)
ct2 = fcks.forward(m2, target_length=conv_len)

# Round trip: decode each ciphertext back to its own message length.
r1 = fcks.backward(ct1, target_length=len1)
r2 = fcks.backward(ct2, target_length=len2)
ok_r1 = np.allclose(r1, m1, atol=atol)
ok_r2 = np.allclose(r2, m2, atol=atol)
print("1D Recovery:", ok_r1, ok_r2)

# Addition on ciphertexts: expected value is the zero-padded elementwise sum.
ct_sum = fcks.cipher_add(ct1, ct2)
r_sum = fcks.backward(ct_sum, target_length=conv_len)
t_sum = np.zeros(conv_len, dtype=complex)
t_sum[:len1] += m1
t_sum[:len2] += m2
ok_sum = np.allclose(r_sum, t_sum, atol=atol)
print("1D Addition:", ok_sum)

# Convolution on ciphertexts vs. the plaintext reference implementation.
ct_conv = fcks.cipher_conv(ct1, ct2)
r_conv = fcks.backward(ct_conv, target_length=conv_len)
t_conv = linear_convolution_direct(m1, m2)
ok_conv = np.allclose(r_conv, t_conv, atol=atol)
print("1D Convolution:", ok_conv)

# Fail loudly (after printing every result) instead of only printing False.
assert ok_r1 and ok_r2 and ok_sum and ok_conv, "1D FourierCKKS checks failed"

Correctly raised exception for non-square data_len: data_len=512 is not a perfect square; cannot support 2D FFT packing.
1D Recovery: True True
1D Addition: True
1D Convolution: True


In [6]:
# ------- 2D Tests -------
# Independent fixed seed so this cell is reproducible on its own.
rng_2d = np.random.default_rng(1)
img_h, img_w = 30, 40
ker_h, ker_w = 10, 15
img = rng_2d.random((img_h, img_w)) + 1j * rng_2d.random((img_h, img_w))
ker = rng_2d.random((ker_h, ker_w)) + 1j * rng_2d.random((ker_h, ker_w))
# Full 2D linear-convolution output size.
out_h = img_h + ker_h - 1
out_w = img_w + ker_w - 1
atol = 1e-1  # shared tolerance for every check in this cell

ct_img = fcks.forward(img, target_height=out_h, target_width=out_w)
ct_ker = fcks.forward(ker, target_height=out_h, target_width=out_w)

# Round trip: decode at the padded size, then compare the original-size corner.
r_img = fcks.backward(ct_img, target_height=out_h, target_width=out_w)
r_ker = fcks.backward(ct_ker, target_height=out_h, target_width=out_w)
ok_img = np.allclose(r_img[:img_h, :img_w], img, atol=atol)
ok_ker = np.allclose(r_ker[:ker_h, :ker_w], ker, atol=atol)
print("2D Recovery:", ok_img, ok_ker)

# Addition on ciphertexts: expected value is the zero-padded elementwise sum.
ct_sum2 = fcks.cipher_add(ct_img, ct_ker)
r_sum2 = fcks.backward(ct_sum2, target_height=out_h, target_width=out_w)
t_sum2 = np.zeros((out_h, out_w), dtype=complex)
t_sum2[:img_h, :img_w] += img
t_sum2[:ker_h, :ker_w] += ker
ok_sum2 = np.allclose(r_sum2, t_sum2, atol=atol)
print("2D Addition:", ok_sum2)

# Convolution on ciphertexts vs. the plaintext reference implementation.
ct_conv2 = fcks.cipher_conv(ct_img, ct_ker)
r_conv2 = fcks.backward(ct_conv2, target_height=out_h, target_width=out_w)
t_conv2 = linear_convolution2d_direct(img, ker)
ok_conv2 = np.allclose(r_conv2, t_conv2, atol=atol)
print("2D Convolution:", ok_conv2)

# Fail loudly (after printing every result) instead of only printing False.
assert ok_img and ok_ker and ok_sum2 and ok_conv2, "2D FourierCKKS checks failed"


2D Recovery: True True
2D Addition: True
2D Convolution: True
