# Why does the Circulant Convolution Theorem not hold as I expect?

In [1]:
import numpy as np
from scipy.fft import fft, ifft
from scipy.signal import convolve

## Example of how it does not hold

In [2]:
a = np.array([[1, 2, 3]]).astype(np.float32)
A = np.array([ # Circulant matrix
    [1, 3, 2],
    [2, 1, 3],
    [3, 2, 1]
]).astype(np.float32)

b = np.array([[4.0, 5.0, 6.0]]).astype(np.float32)

# Matrix-vector multiplication
A @ b.T

array([[31.],
       [31.],
       [28.]], dtype=float32)

In [3]:
# FFT-multiplication
a_fft = fft(a, axis=-1)
b_fft = fft(b, axis=-1)
c_fft = a_fft * b_fft
ifft(c_fft, axis=-1)

array([[31.+0.j, 31.+0.j, 28.-0.j]], dtype=complex64)

# Answer: Axis=0 is important

In [4]:
a.shape

(1, 3)

In [5]:
a_fft.shape

(1, 3)

In [6]:
fft(a).shape

(1, 3)

In [7]:
fft(a, axis=0).shape

(1, 3)

In [8]:
fft(a[:, 0])

array([1.-0.j], dtype=complex64)

In [9]:
fft(a, axis=0)

array([[1.-0.j, 2.-0.j, 3.-0.j]], dtype=complex64)

# Larger shape example

In [18]:
def make_circulant(a):
    n = a.shape[-1] # event shape
    new_a = np.zeros(a.shape + (a.shape[-1],))
    for i_posterior_sample in range(a.shape[0]):
        for i_data_sample in range(a.shape[1]):
            vector = a[i_posterior_sample, i_data_sample]
            new_a[i_posterior_sample, i_data_sample] = np.array([
                np.roll(vector, i)
                for i in range(n)
            ]).T
    return new_a

batch_shape = (100, 3)
event_shape = (4,)
total_shape = batch_shape + event_shape
a = np.random.randn(*total_shape)
A = make_circulant(a)
b = np.random.randn(2,4) # two points of 4 features

In [19]:
a.shape, A.shape, b.shape, (b @ A).shape
# Before, fft was in event shape 

((100, 3, 4), (100, 3, 4, 4), (2, 4), (100, 3, 2, 4))

In [65]:
(b @ A).shape


(100, 3, 2, 4)

In [13]:
a_fft = fft(a, axis=-1)
b_fft = fft(b, axis=-1)
r = (ifft(np.einsum("ad,bcd->bcad", b_fft, a_fft), axis=-1))[8]
if np.allclose(np.imag(r), 0):
    print(np.real(r))
else:
    print(r)

if not np.allclose(r, A @ b):
    print("Not equal")

[[[-0.83285108  5.64374111 -4.60528807 -0.51174403]
  [-1.83441373  0.27650238  2.09021669 -3.05863214]]

 [[ 1.92100202 -2.40551611 -1.52590973  0.64651272]
  [-1.70363484 -1.61176604 -4.23350825 -3.70627455]]

 [[ 0.61268049 -4.333021    4.18417123 -0.94828887]
  [-0.726772   -1.7199233  -2.72040372  1.16928327]]]


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 4)

In [17]:
first_a = a[0, 0]
ifft(fft(a) * fft(b))


ValueError: operands could not be broadcast together with shapes (100,3,4) (2,4) 

In [14]:
a[0, 0], A[0, 0]

(array([-0.28392885,  0.17000811,  1.23580274,  0.0095501 ]),
 array([[-0.28392885,  0.0095501 ,  1.23580274,  0.17000811],
        [ 0.17000811, -0.28392885,  0.0095501 ,  1.23580274],
        [ 1.23580274,  0.17000811, -0.28392885,  0.0095501 ],
        [ 0.0095501 ,  1.23580274,  0.17000811, -0.28392885]]))

In [156]:
def expand_circ_mult(w,x): # w has (num_circ, D_X), x has (N, D_X)
    x_fft = ifft(x)
    x_fft = np.repeat(x_fft[:, None, :], w.shape[0], axis=1)
    return np.real(fft(fft(w) * x_fft)).reshape(x.shape[0], -1) # (N, num_circ * D_X)

def expand_circ_mult_2(w,x): # w has (num_post, num_circ, D_X), x has (N, D_X)
    return np.real(ifft(fft(x, axis=-1)[..., None, None, :] * fft(w, axis=-1), axis=-1)).transpose(1, 2,3,0)
truth = np.array([[[b @ A_circ for A_circ in data_sample] for data_sample in post_sample] for post_sample in A])
print("True", truth.transpose(0,3,1,2).reshape(100, 2, -1))
print("FFT-based", expand_circ_mult_2(a, b).reshape(100, 2, -1))

True [[[ 0.52363488 -1.22758976 -1.52798296 ...  0.58603447  1.87662215
   -0.91758101]
  [-0.73852423  3.48728588  0.00640592 ... -0.7024551   1.11640747
   -0.49100923]]

 [[-0.19353404  1.03899735  0.25784952 ... -0.52844942  0.68942005
    0.52613533]
  [-0.12143285 -0.05526067 -0.42424787 ...  1.35173033 -2.13052846
    0.95097287]]

 [[-2.60520184 -0.86420335  1.38326298 ... -1.44966662 -1.20475083
    1.79842337]
  [ 3.81754251 -1.84309725  2.55036778 ... -1.63255769  3.28272146
   -2.69400548]]

 ...

 [[ 2.00688185 -1.35599455 -1.685945   ...  1.82425295 -1.83505736
   -2.44084453]
  [-0.39432464 -1.07327163  1.14458065 ... -0.59898019  1.6397398
    1.0854409 ]]

 [[ 0.30695137  2.55152962  1.11180408 ...  0.15822114 -1.40297139
   -0.44882345]
  [-1.70539472 -1.09572393 -2.26190555 ... -1.60104401  3.64273383
   -1.31578524]]

 [[-1.41233115 -0.8736664   1.43804932 ...  0.66559273 -1.99580553
   -0.51885394]
  [-0.21034882  1.46260396 -1.13345477 ... -1.17227253  1.71510204


In [None]:
A.shape
# (num_post, num_circ, D_X, D_X)

(100, 3, 4, 4)

In [89]:
(b @ A)[0]

array([[[ 0.1318192 , -0.24024359, -1.13616728, -1.88626097],
        [-0.53544627,  3.16496091, -0.19667205,  3.828675  ]],

       [[ 0.79125479, -0.45769794, -0.80911284,  0.20529651],
        [-0.57248073,  0.60305945,  0.03534716,  0.47457689]],

       [[-1.79361821, -2.37406887,  0.70194307,  2.04252232],
        [ 2.31457905,  0.26388806,  1.72524334, -1.4573524 ]]])

In [130]:
# First circ matrix of A
print((A[0, :] @ b.T))
print(np.real(ifft(fft(b, axis=-1)[..., None, :] * fft(a[0], axis=-1), axis=-1)).transpose(1,2,0))

[[[ 0.52363488 -0.73852423]
  [-1.22758976  3.48728588]
  [-1.52798296  0.00640592]
  [-0.8989148   3.50635003]]

 [[ 0.34621344 -0.34181591]
  [ 0.66377294  0.23694867]
  [-0.36407149 -0.19531766]
  [-0.91617437  0.84068766]]

 [[-2.96829729  2.92341492]
  [ 0.58603447 -0.7024551 ]
  [ 1.87662215  1.11640747]
  [-0.91758101 -0.49100923]]]
[[[ 0.52363488 -0.73852423]
  [-1.22758976  3.48728588]
  [-1.52798296  0.00640592]
  [-0.8989148   3.50635003]]

 [[ 0.34621344 -0.34181591]
  [ 0.66377294  0.23694867]
  [-0.36407149 -0.19531766]
  [-0.91617437  0.84068766]]

 [[-2.96829729  2.92341492]
  [ 0.58603447 -0.7024551 ]
  [ 1.87662215  1.11640747]
  [-0.91758101 -0.49100923]]]


In [None]:
print((A[:, :] @ b.T)[1])
print(np.real(ifft(fft(b, axis=-1)[..., None, None, :] * fft(a, axis=-1), axis=-1)).transpose(1, 2,3,0))

[[[-0.19353404 -0.12143285]
  [ 1.03899735 -0.05526067]
  [ 0.25784952 -0.42424787]
  [-1.14543758  0.68518839]]

 [[ 0.65099134  0.74912699]
  [ 1.26007233  0.17605406]
  [-1.62438439  1.32497681]
  [-2.18131337  1.53899674]]

 [[ 0.24917472 -2.04467996]
  [-0.52844942  1.35173033]
  [ 0.68942005 -2.13052846]
  [ 0.52613533  0.95097287]]]
[[[-0.19353404 -0.12143285]
  [ 1.03899735 -0.05526067]
  [ 0.25784952 -0.42424787]
  [-1.14543758  0.68518839]]

 [[ 0.65099134  0.74912699]
  [ 1.26007233  0.17605406]
  [-1.62438439  1.32497681]
  [-2.18131337  1.53899674]]

 [[ 0.24917472 -2.04467996]
  [-0.52844942  1.35173033]
  [ 0.68942005 -2.13052846]
  [ 0.52613533  0.95097287]]]


In [None]:
print(A[0,0], a[0, 0])

[[-0.72894044  0.33606784 -1.79386313 -0.29740967]
 [-0.29740967 -0.72894044  0.33606784 -1.79386313]
 [-1.79386313 -0.29740967 -0.72894044  0.33606784]
 [ 0.33606784 -1.79386313 -0.29740967 -0.72894044]] [-0.72894044 -0.29740967 -1.79386313  0.33606784]


In [103]:
a[0,0].reshape(-1,1)

array([[-0.72894044],
       [-0.29740967],
       [-1.79386313],
       [ 0.33606784]])

In [117]:
A.shape

(100, 3, 4, 4)