<a href="https://colab.research.google.com/github/Al-Kindi-0/Al-Kindi-0/blob/main/circulant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sympy
from scipy.linalg import circulant
from scipy.fft import fft, ifft
import numpy as np

ModuleNotFoundError: No module named 'scipy.fft'

In [2]:
print(circulant([1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]).T)

[[    1     1     2     1     8    32     2   256  4096     8 65536  1024]
 [ 1024     1     1     2     1     8    32     2   256  4096     8 65536]
 [65536  1024     1     1     2     1     8    32     2   256  4096     8]
 [    8 65536  1024     1     1     2     1     8    32     2   256  4096]
 [ 4096     8 65536  1024     1     1     2     1     8    32     2   256]
 [  256  4096     8 65536  1024     1     1     2     1     8    32     2]
 [    2   256  4096     8 65536  1024     1     1     2     1     8    32]
 [   32     2   256  4096     8 65536  1024     1     1     2     1     8]
 [    8    32     2   256  4096     8 65536  1024     1     1     2     1]
 [    1     8    32     2   256  4096     8 65536  1024     1     1     2]
 [    2     1     8    32     2   256  4096     8 65536  1024     1     1]
 [    1     2     1     8    32     2   256  4096     8 65536  1024     1]]


In [3]:
M = circulant([1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]).T

In [4]:
v = np.array([i for i in range(12)])

In [5]:
# We need the characterizing vector of the circulant matrix i.e. the first column
zz = [1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]
zz = zz[1:]
zz.reverse()
zz = [1] + zz
zz

[1, 1024, 65536, 8, 4096, 256, 2, 32, 8, 1, 2, 1]

Check that the fourier-based multiplication works


In [6]:
z = fft(zz)
v_hat = fft(v)
tmp = np.multiply(v_hat,z)
ifft(tmp)

array([701468.+0.j, 760147.+0.j,  44682.+0.j, 115553.+0.j, 137368.+0.j,
       205263.+0.j, 276206.+0.j, 346789.+0.j, 417660.+0.j, 488615.+0.j,
       559558.+0.j, 630513.-0.j])

In [7]:
M.dot(v)

array([701468, 760147,  44682, 115553, 137368, 205263, 276206, 346789,
       417660, 488615, 559558, 630513])

Yes it does!

We can also flip the order of the operators and use the first row of M instead of first column then. This should give the same result.

In [8]:
cc = np.multiply(fft([1, 1, 2, 1, 8, 32, 2, 256, 4096, 8, 65536, 1024]), ifft(v))
fft(cc)

array([701468.+0.j, 760147.+0.j,  44682.+0.j, 115553.+0.j, 137368.+0.j,
       205263.-0.j, 276206.+0.j, 346789.+0.j, 417660.+0.j, 488615.+0.j,
       559558.+0.j, 630513.+0.j])

**Dissection of the 12 point FFT as 4x3 split radix FFT using matrix formulation**

TODO: Write the matrix formulation of the 4x3 split radix FFT and list the reference.

DONE: See below!

In [9]:
def FFT4(x):
  # We use the scipy.fft for FFT4 as it is not the point here to implement it
  return fft(x)
def FFT12(x):

  # See page 81 in "Computational frameworks for the fast Fourier transform" (Charles Van Loan)
  N = len(x)
  assert(N == 12)
  theta = np.linspace(0,2*np.pi,N+1)
  theta = theta[:-1]
  zeta = np.exp(-1j*theta)
  Nthird = N//3
  assert(Nthird == 4)
  I = np.eye(Nthird)
  omega = np.exp(-2*np.pi*1j/3)
  F_3 = np.zeros([3,3],dtype= np.complex64)
  for i in range(3):
    for j in range(3):
      F_3[i,j] = omega**(i*j)

  
  tmp = [FFT4(x[0::3]),
                                zeta[0:Nthird]* FFT4(x[1::3]),
                                zeta[0:Nthird]**2 * FFT4(x[2::3])]
  tmp = np.concatenate(tmp)

  v = np.kron(F_3,np.eye(Nthird)) @ tmp # There are 4 length 3 vectors resulting from FFT3. These 4 vectors are multiplied with there counterparts in the representation
                                        # of the MDS matrix in the frequency domain. Check further down for how we will compute the component-wise multiplication in the frequency
                                        # domain while avoiding FFTs of length 3.

  return v

Check that our split-radix implementation works.

In [10]:
np.isclose(fft(v),FFT12(v))

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True])

It does!

**Symblic computations in order to compute the state intermediate between the 3xFFT4 and 3xiFFT4**

W

This is the inverse from inverting the matrix representation above.

In [11]:
def iFFT3(x):
  return ifft(x)
def iFFT12(x):

  N = len(x)
  assert(N == 12)
  theta = np.linspace(0,2*np.pi,N+1)
  theta = theta[:-1]
  #theta[:] = theta[::-1]
  zeta = np.exp(1j*theta)
  Nfourth = N//4
  assert(Nfourth == 3)
  I = np.eye(Nfourth)
  omega = np.exp(2*np.pi*1j/4)
  F_4 = np.zeros([4,4],dtype= np.complex64)
  for i in range(4):
    for j in range(4):
      F_4[i,j] = omega**(i*j)/4

  # After multiplying the 4 length 3 vectors of the representation of x in the frequency domain with their counterpart for MDS, we transform back:
  tmp = [iFFT3(x[0::4]), # x[0::4] is the result of the component-wise multiplication of the first 2 length 3 vectors
                                zeta[0:Nfourth]* iFFT3(x[1::4]),  # This is the result of the second multiplication
                                zeta[0:Nfourth]**2 * iFFT3(x[2::4]), # Third MDS multiplication (should be real)
                                zeta[0:Nfourth]**3 * iFFT3(x[3::4])] # This is just the conjugate of the second MDS multiplication
  tmp = np.concatenate(tmp)

  v = np.kron(F_4,np.eye(Nfourth)) @ tmp

  return v

Sanity check that our forward and backward transforms are correct:

In [12]:
iFFT12(FFT12(v))

array([ 0.        +0.00000000e+00j,  1.00000002+4.35071977e-08j,
        2.00000002-2.01917595e-08j,  3.        +1.83697008e-16j,
        4.00000002-3.12367547e-09j,  4.99999999-2.01917602e-08j,
        6.        -3.67394016e-16j,  7.00000001-2.01917612e-08j,
        7.99999998-3.12367691e-09j,  9.        -1.65327317e-15j,
        9.99999998-2.01917629e-08j, 10.99999998+4.35071947e-08j])

Indeed!

***Finding expressions for the intermediate computations between the 3 FFT4 and 3 iFFT4***

In order to avoid the FFT3 and iFFT3 we perform the following manoeuver.
We know from the above, for the 3-4 split radix FFT12, that after computing the 3 FFT4s, the rest of the "Circulant matrix multiplied by a vector" story goes as follows:

Take the length 3 length 4 vectors and look at them as 4 length 3 vectors with the appropriate indexing -> multiply by linking factors i.e. twiddles -> apply FFT3 -> Multiply with the second similarly transformed vector representing the appropriate portion of the MDS matrix in frequency domain -> apply iFFT3 -> multiply by reverse-linking factors i.e. inverse twiddles -> Represent the 4 length 3 vectors as 3 length 4 vectors -> apply 3 iFFT4.

In our context, see the reference above, we have:

* Forward: $\zeta = [exp(-i2j\pi/12) \text{for j in range(N)}]$ 
i.e. $\omega = exp(-i2\pi/12)$
and $D := diag(I_4,\Omega_{3,4}, \Omega_{3,4}^2)$ equals $diag(1,1,1,1,1,\omega^1,\omega^2,\omega^3,1,\omega^2,\omega^4,\omega^6)$
The scaling factors we multiply with are as follows then:
1. $D[0::3] = (1,1,1)$
2. $D[1::3] = (1,\omega,\omega^2)$
3. $D[2::3] = (1,\omega^2,\omega^4)$
4. $D[3::3] = (1,\omega^3,\omega^6)$

* Backward linking factors: let $\omega' = exp(i2\pi/12)$, then the linking factors are as follows:
1. $(1,1,1)$
2. $(1,\omega',\omega'^2)$
3. $(1,\omega'^2,\omega'^4)$
4. $(1,\omega'^3,\omega'^6)$

So for example, for the second block of multiplications, if we want to convolve $x :=[x_0,x_1,x_2]$ and $y:= [y_0,y_1,y_2]$, we do, for $d := (1,\omega,\omega^2)$ and $d' := (1,\omega',\omega'^2)$, the following:\
$d' * iFFT3(FFT3(d * x) * FFT3(d * y))$
where $*$ stands here for the Hadamard product.
In what follows we implement the following trick: compute $d' * iFFT3(FFT3(d * x) * FFT3(d * y))$ symbolically in $x,y$ so as to avoid FFT3 and iFFT3 altogether.


In [13]:
from numpy.lib.type_check import real
from sympy import re, im, I, E, symbols, simplify, cancel

x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11= symbols('x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11', real=True)
y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11= symbols('y0 y1 y2 y3 y4 y5 y6 y7 y8 y9 y10 y11', real=True)
f0, f1, g0, g1 = symbols('f0 f1 g0 g1', real= True)

In [14]:
X = [x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11]
Y = [y0,y1,y2,y3,y4,y5,y6,y7,y8,y9,y10,y11]

**FFT3 and iFFT3 symbolically**


In [15]:
from sympy.matrices import Matrix, eye, zeros, ones, diag, GramSchmidt

From: Some FFT Algorithms for Small-Length Real-Valued Sequences
https://www.mdpi.com/2076-3417/12/9/4700/htm

In [16]:
A3_hat= Matrix(([1,0,0],[0,1,1],[0,1,-1]))
A3_tild= Matrix(([1,1,0],[0,1,0],[0,0,1]))
D3= Matrix(([1,0,0],[0,g0,0],[0,0,g1])) # g0 == -3/2 & g1 == np.sqrt(3)/2
C3= Matrix(([1,0,0],[1,1,0],[0,0,1]))

In [17]:
Forw3 = C3 * D3 * A3_tild * A3_hat
Forw3

Matrix([
[1,      1,      1],
[1, g0 + 1, g0 + 1],
[0,     g1,    -g1]])

In [18]:
Back3 = Forw3**(-1)
Back3

Matrix([
[ 1 + 1/g0,    -1/g0,         0],
[-1/(2*g0), 1/(2*g0),  1/(2*g1)],
[-1/(2*g0), 1/(2*g0), -1/(2*g1)]])

In [19]:
def fft3s(x):
  [x0,x1,x2] = x
  y0 = x0 + x1 + x2
  y1 = x0 + (1 + g0)*(x1 + x2)
  y2 = g1*(x1-x2)
  return [y0,y1 - y2*I,y1 + I*y2]

def ifft3s(y):
  [z0,z1,z2] = y
  y0 = z0
  y1 = (z1 + z2)*0.5
  y2 = (z2 - z1)/(2*I)

  x0 = y0 + (y0 - y1)/g0
  x1 = ((y2/(2*g1)) - (y0/(2*g0)) + (y1/(2*g0)))
  x2 = ((-y2/(2*g1)) - (y0/(2*g0)) + (y1/(2*g0)))
  return [x0,x1,x2]

In [20]:
omega = np.exp(-2*1j*np.pi/12)

In [21]:
omega_prime = np.exp(2*1j*np.pi/12)

**First block**

In [22]:
tmpo = [a*b for (a,b) in zip(fft3s([x0,x1,x2]),fft3s([y0,y1,y2]))]

In [23]:
expr = [0]*3
for i in range(3):
  expr[i] = (ifft3s(tmpo)[i].simplify().subs({g0:-3/2, g1: np.sqrt(3)/2})).simplify().nsimplify(tolerance = 1e-14)


In [24]:
expr[0]

x0*y0 + x1*y2 + x2*y1

In [25]:
expr[1]

x0*y1 + x1*y0 + x2*y2

In [26]:
expr[2]

x0*y2 + x1*y1 + x2*y0

In [27]:
def block1(x,y):
  [x0,x1,x2] = x
  [y0,y1,y2] = y
  z0 = x0*y0 + x1*y2 + x2*y1
  z1 = x0*y1 + x1*y0 + x2*y2
  z2 = x0*y2 + x1*y1 + x2*y0
  return [z0, z1, z2]

**Second block**

In [28]:
x_scaled  =[(x0+I*x3),(x1+I*x4)*omega,(x2+I*x5)*omega**2]
y_scaled  =[(y0+I*y3),(y1+I*y4)*omega,(y2+I*y5)*omega**2]

In [29]:
x_scaled  =[(x0),(x1)*omega,(x2)*omega**2]
y_scaled  =[(y0),(y1)*omega,(y2)*omega**2]

In [30]:
tmpo2 = [a*b for (a,b) in zip(fft3s(x_scaled),fft3s(y_scaled))]

In [31]:
expr = [0]*3
for i in range(3):
  expr[i] = (omega_prime**i * ifft3s(tmpo2)[i].simplify().subs({g0:-3/2, g1: np.sqrt(3)/2})).simplify().nsimplify(tolerance = 1e-14)

In [32]:
expr[0]

x0*y0 - I*x1*y2 - I*x2*y1

In [33]:
expr[1]

x0*y1 + x1*y0 - I*x2*y2

In [34]:
expr[2]

x0*y2 + x1*y1 + x2*y0

In [35]:
def block2(x,y):
  [(x0r,x0i),(x1r,x1i),(x2r,x2i)] = x
  [(y0r,y0i),(y1r,y1i),(y2r,y2i)] = y
  x0s = x0r + x0i
  x1s = x1r + x1i
  x2s = x2r + x2i
  y0s = y0r + y0i
  y1s = y1r + y1i
  y2s = y2r + y2i

  # Compute x0​y0 ​− ix1​y2 ​− ix2​y1​ using Karatsuba
  m0 = (x0r*y0r , x0i*y0i)
  m1 = (x1r*y2r , x1i*y2i)
  m2 = (x2r*y1r , x2i*y1i)
  #print("y0s, y1s, y2s", y0s, y1s, y2s)
  z0r = (m0[0] - m0[1])           + (x1s*y2s - m1[0] - m1[1]) + (x2s*y1s - m2[0] - m2[1])
  z0i = (x0s*y0s - m0[0] - m0[1]) + (-m1[0] + m1[1])           + (-m2[0] + m2[1]) 
  z0 = (z0r, z0i)

  # Compute x0​y1​ + x1​y0 ​− ix2​y2 using Karatsuba
  m0 = (x0r*y1r , x0i*y1i)
  m1 = (x1r*y0r , x1i*y0i)
  m2 = (x2r*y2r , x2i*y2i)
  #print("y0s, y1s, y2s", y0s, y1s, y2s)
  z1r = (m0[0] - m0[1])           + (m1[0] - m1[1])           + (x2s*y2s - m2[0] - m2[1]) 
  z1i = (x0s*y1s - m0[0] - m0[1]) + (x1s*y0s - m1[0] - m1[1]) + (-m2[0] + m2[1])
  z1 = (z1r, z1i)

  # Compute x0​y2​ + x1​y1 ​+ x2​y0​ using Karatsuba
  m0 = (x0r*y2r , x0i*y2i)
  m1 = (x1r*y1r , x1i*y1i)
  m2 = (x2r*y0r , x2i*y0i)
  #print("y0s, y1s, y2s", y0s, y1s, y2s)
  z2r = (m0[0] - m0[1])           + (m1[0] - m1[1])           + (m2[0] - m2[1])
  z2i = (x0s*y2s - m0[0] - m0[1]) + (x1s*y1s - m1[0] - m1[1]) + (x2s*y0s - m2[0] - m2[1])
  z2 = (z2r, z2i)
  

  return [z0, z1, z2]

**Third block**

In [36]:
x_scaled  =[x0,x1*omega**2,x2*omega**4]
y_scaled  =[y0,y1*omega**2,y2*omega**4]
tmpo3 = [a*b for (a,b) in zip(fft3s(x_scaled),fft3s(y_scaled))]
for i in range(3):
  expr[i] = (omega_prime**(2*i) * ifft3s(tmpo3)[i].simplify().subs({g0:-3/2, g1: np.sqrt(3)/2})).simplify().nsimplify(tolerance = 1e-14)


In [37]:
expr[0]

x0*y0 - x1*y2 - x2*y1

In [38]:
expr[1]

x0*y1 + x1*y0 - x2*y2

In [39]:
expr[2]

x0*y2 + x1*y1 + x2*y0

In [40]:
def block3(x,y):
  [x0,x1,x2] = x
  [y0,y1,y2] = y
  z0 = x0*y0 - x1*y2 - x2*y1
  z1 = x0*y1 + x1*y0 - x2*y2
  z2 = x0*y2 + x1*y1 + x2*y0
  return [z0, z1, z2]

**Forth block (should be the conjugate of the second block)**

In [41]:
x_scaled  =[x0,x1*omega**3,x2*omega**6]
y_scaled  =[y0,y1*omega**3,y2*omega**6]
tmpo4 = [a*b for (a,b) in zip(fft3s(x_scaled),fft3s(y_scaled))]
for i in range(3):
  expr[i] = (omega_prime**(3*i) * ifft3s(tmpo4)[i].simplify().subs({g0:-3/2, g1: np.sqrt(3)/2})).simplify().nsimplify(tolerance = 1e-14)

In [42]:
expr[0]

x0*y0 + I*x1*y2 + I*x2*y1

In [43]:
expr[1]

x0*y1 + x1*y0 + I*x2*y2

In [44]:
expr[2]

x0*y2 + x1*y1 + x2*y0

Indeed!

**Actual implementation of multiplication of a circular matrix by a vector in the frequency domain in O(nlog(n))**

In [45]:
def fft2(x):
  x0 = x[0]
  x1 = x[1]
  return [x0 + x1, x0 - x1]

def ifft2(x):
  x0 = x[0]
  x1 = x[1]
  return [(x0 + x1) >> 1 , (x0 - x1) >> 1]

def fft4(x):
  z = [0,0,0,0]
  z[0], z[2] = fft2([x[0],x[2]])
  z[1], z[3] = fft2([x[1],x[3]])
  y0 = z[0] + z[1]
  y1 = (z[2],-z[3])
  y2 = z[0] - z[1]
  return (y0, y1, y2)

def ifft4(y):
  z = [0,0,0,0]
  z[0] = (y[0] + y[2]) >> 1
  z[1] = (y[0] - y[2]) >> 1
  z[2] = y[1][0]
  z[3] = -y[1][1]
  x = [0,0,0,0]
  [x[0],x[2]] = ifft2([z[0],z[2]])
  [x[1],x[3]] = ifft2([z[1],z[3]])
  return x


In [46]:
ifft4(fft4([34,5,-1,4]))

[34, 5, -1, 4]

In [47]:
def fft4x3(x):
  [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11] = x
  (y0, y1, y2) = fft4([x0, x3, x6, x9])   # y3 is conjugate of y1
  (y4, y5, y6) = fft4([x1, x4, x7, x10])  # y7 is conjugate of y5
  (y8, y9, y10) = fft4([x2, x5, x8, x11]) # y11 is conjugate of y9
  return (
      [y0, y4, y8],
      [y1, y5, y9],
      [y2, y6, y10]
  )

In [48]:
def mds_frequency(x):
  return fft4x3(x)

In [49]:
def block1(x,y):
  [x0,x1,x2] = x
  [y0,y1,y2] = y
  z0 = x0*y0 + x1*y2 + x2*y1
  z1 = x0*y1 + x1*y0 + x2*y2
  z2 = x0*y2 + x1*y1 + x2*y0
  return [z0, z1, z2]


def block2(x,y):
  [(x0r,x0i),(x1r,x1i),(x2r,x2i)] = x
  [(y0r,y0i),(y1r,y1i),(y2r,y2i)] = y
  x0s = x0r + x0i
  x1s = x1r + x1i
  x2s = x2r + x2i
  y0s = y0r + y0i
  y1s = y1r + y1i
  y2s = y2r + y2i

  # Compute x0​y0​−ix1​y2​−ix2​y1​ using Karatsuba
  m0 = (x0r*y0r , x0i*y0i)
  m1 = (x1r*y2r , x1i*y2i)
  m2 = (x2r*y1r , x2i*y1i)
  z0r = (m0[0] - m0[1])           + (x1s*y2s - m1[0] - m1[1]) + (x2s*y1s - m2[0] - m2[1])
  z0i = (x0s*y0s - m0[0] - m0[1]) + (-m1[0] + m1[1])           + (-m2[0] + m2[1]) 
  z0 = (z0r, z0i)

  #print("z0 is (correct) {?}", z0)

  # Compute x0​y1​+x1​y0​−ix2​y2 using Karatsuba
  m0 = (x0r*y1r , x0i*y1i)
  m1 = (x1r*y0r , x1i*y0i)
  m2 = (x2r*y2r , x2i*y2i)
  z1r = (m0[0] - m0[1])           + (m1[0] - m1[1])           + (x2s*y2s - m2[0] - m2[1]) 
  z1i = (x0s*y1s - m0[0] - m0[1]) + (x1s*y0s - m1[0] - m1[1]) + (-m2[0] + m2[1])
  z1 = (z1r, z1i)

  # Compute x0​y2​+x1​y1​+x2​y0​ using Karatsuba
  m0 = (x0r*y2r , x0i*y2i)
  m1 = (x1r*y1r , x1i*y1i)
  m2 = (x2r*y0r , x2i*y0i)
  z2r = (m0[0] - m0[1])           + (m1[0] - m1[1])           + (m2[0] - m2[1])
  z2i = (x0s*y2s - m0[0] - m0[1]) + (x1s*y1s - m1[0] - m1[1]) + (x2s*y0s - m2[0] - m2[1])
  z2 = (z2r, z2i)


  return [z0, z1, z2]


def block3(x,y):
  [x0,x1,x2] = x
  [y0,y1,y2] = y
  z0 = x0*y0 - x1*y2 - x2*y1
  z1 = x0*y1 + x1*y0 - x2*y2
  z2 = x0*y2 + x1*y1 + x2*y0
  return [z0, z1, z2]

In [50]:
def MDS_multiply_freq(state):
  [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state

  (u0, u1, u2) = fft4([s0, s3, s6, s9])
  (u4, u5, u6) = fft4([s1, s4, s7, s10])
  (u8, u9, u10) = fft4([s2, s5, s8, s11])

  [v0, v4, v8] = block1([u0, u4, u8], MDS_FREQ_BLOCK_ONE)
  [v1, v5, v9] = block2([u1, u5, u9], MDS_FREQ_BLOCK_TWO)
  [v2, v6, v10] = block3([u2, u6, u10], MDS_FREQ_BLOCK_THREE)

  [s0, s3, s6, s9] = ifft4((v0, v1, v2))
  [s1, s4, s7, s10] = ifft4((v4, v5, v6))
  [s2, s5, s8, s11] = ifft4((v8, v9, v10))

  return [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]

In [51]:
MDS_FREQ_BLOCK_ONE = mds_frequency([1, 1024, 65536, 8, 4096, 256, 2, 32, 8, 1, 2, 1])[0]
MDS_FREQ_BLOCK_TWO = mds_frequency([1, 1024, 65536, 8, 4096, 256, 2, 32, 8, 1, 2, 1])[1]
MDS_FREQ_BLOCK_THREE = mds_frequency([1, 1024, 65536, 8, 4096, 256, 2, 32, 8, 1, 2, 1])[2]

In [52]:
MDS_FREQ_BLOCK_THREE

[-6, -3042, 65287]

In [53]:
MDS_multiply_freq(v**12) - M.dot(v**12)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

Yay!