In [1]:
import numpy

$$\begin{align*}
S^{T} x
&= \begin{bmatrix}
5 & 10 \\
17 & 19
\end{bmatrix} \otimes \begin{bmatrix}
1 & 0 \\
0 & 0
\end{bmatrix} \\
&= \begin{bmatrix}
5 & 0 & 10 & 0 \\
0 & 0 & 0 & 0 \\
17 & 0 & 19 & 0 \\
0 & 0 & 0 & 0
\end{bmatrix}
\end{align*}$$

In [2]:
# S^T x: Kronecker product of the 2x2 image x with the 2x2 unit impulse,
# which places each pixel of x in the top-left corner of its own 2x2 block
# and fills the other three positions with zeros.
numpy.kron(
    numpy.array([[5, 10], [17, 19]]),
    numpy.array([[1, 0], [0, 0]]),
)

array([[ 5,  0, 10,  0],
       [ 0,  0,  0,  0],
       [17,  0, 19,  0],
       [ 0,  0,  0,  0]])

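$$\begin{align*}
\operatorname{FFT}(S^{T} x)
&= \operatorname{FFT}\left(\begin{bmatrix}
5 & 10 \\
17 & 19
\end{bmatrix} \otimes \begin{bmatrix}
1 & 0 \\
0 & 0
\end{bmatrix}\right) \\
&= \operatorname{FFT}\left(\begin{bmatrix}
1 & 0 \\
0 & 0
\end{bmatrix}\right) \otimes \operatorname{FFT}\left(\begin{bmatrix}
5 & 10 \\
17 & 19
\end{bmatrix}\right) \\
&= \begin{bmatrix}
1 & 1 \\
1 & 1
\end{bmatrix} \otimes \operatorname{FFT}\left(\begin{bmatrix}
5 & 10 \\
17 & 19
\end{bmatrix}\right)
\end{align*}$$

The second equality is not a general property of the DFT of a Kronecker product. It holds here because the right factor $\begin{bmatrix} 1 & 0 \\ 0 & 0 \end{bmatrix}$ is a unit impulse. Kronecker-multiplying $x$ by it zero-upsamples $x$ by 2 in each direction. The DFT of a zero-upsampled signal is the DFT of the original signal tiled periodically, which is exactly $J_2 \otimes \operatorname{FFT}(x)$, where $J_2$ is the $2 \times 2$ all-ones matrix (the FFT of the unit impulse). The cells below check this numerically.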

In [3]:
# Left-hand side of the identity: the 2-D FFT of S^T x, computed directly
# on the 4x4 upsampled image.
numpy.fft.fft2(
    numpy.kron(
        numpy.array([[5, 10], [17, 19]]),
        numpy.array([[1, 0], [0, 0]]),
    )
)

array([[ 51.+0.j,  -7.+0.j,  51.+0.j,  -7.+0.j],
       [-21.+0.j,  -3.+0.j, -21.+0.j,  -3.+0.j],
       [ 51.+0.j,  -7.+0.j,  51.+0.j,  -7.+0.j],
       [-21.+0.j,  -3.+0.j, -21.+0.j,  -3.+0.j]])

In [4]:
# Right-hand side of the identity: FFT(unit impulse) ⊗ FFT(x) = J_2 ⊗ FFT(x),
# i.e. the 2x2 spectrum of x tiled over a 2x2 grid.
# Both sides are computed here and compared, so the identity is checked
# programmatically instead of by eye against the previous cell's output.
x_image = numpy.array([[5, 10], [17, 19]])
unit_impulse = numpy.array([[1, 0], [0, 0]])

fft_of_upsampled = numpy.fft.fft2(numpy.kron(x_image, unit_impulse))
fft_tiled = numpy.kron(numpy.array([[1, 1], [1, 1]]), numpy.fft.fft2(x_image))

assert numpy.allclose(fft_of_upsampled, fft_tiled), \
    "FFT(S^T x) does not match J_2 ⊗ FFT(x)"
fft_tiled

array([[ 51.+0.j,  -7.+0.j,  51.+0.j,  -7.+0.j],
       [-21.+0.j,  -3.+0.j, -21.+0.j,  -3.+0.j],
       [ 51.+0.j,  -7.+0.j,  51.+0.j,  -7.+0.j],
       [-21.+0.j,  -3.+0.j, -21.+0.j,  -3.+0.j]])

In [5]:
# FFT of the 2x2 unit impulse: every frequency coefficient equals 1 (J_2).
numpy.fft.fft2(numpy.array([[1, 0], [0, 0]]))

array([[1.+0.j, 1.+0.j],
       [1.+0.j, 1.+0.j]])

In [6]:
# 2-D FFT of the 2x2 image x: the factor tiled by J_2 in the identity above.
numpy.fft.fft2(numpy.array([[5, 10], [17, 19]]))

array([[ 51.+0.j,  -7.+0.j],
       [-21.+0.j,  -3.+0.j]])

In [7]:
# Diagonal 0/1 mask that keeps the even-indexed entries of a length-4 vector.
# NOTE(review): the name suggests this is S^T S for the decimation operator S — confirm.
STS = numpy.diag([1, 0, 1, 0])
STS

array([[1, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 0]])

In [8]:
import sys

# Make the project package `lasp` (expected one directory above the notebook's
# working directory) importable. Guarded so that re-running this cell does not
# append '..' to sys.path again on every execution.
if '..' not in sys.path:
    sys.path.append('..')

import lasp.utils

# NOTE(review): the division by 2 looks like a normalization (perhaps by d = 2);
# confirm against lasp.utils.fourier_diagonalization. The displayed result has
# the same values as J_d ⊗ I_d computed in a later cell.
lasp.utils.fourier_diagonalization(STS, shape_out=(4, 4)) / 2

array([[1.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 1.+0.j],
       [1.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 1.+0.j]])

In [9]:
# J_d (d x d all-ones matrix) and I_d (d x d identity), here with d = 2.
# NOTE(review): `id` shadows Python's builtin id(); renaming it would also
# require updating every later cell that uses it.
Jd = numpy.ones((2, 2))
id = numpy.identity(2)

In [10]:
# Display J_d, the 2x2 all-ones matrix.
Jd

array([[1., 1.],
       [1., 1.]])

In [11]:
# Display I_d, the 2x2 identity (stored in `id`, which shadows the builtin id()).
id

array([[1., 0.],
       [0., 1.]])

In [12]:
# J_d ⊗ I_d: a 2x2 grid of identity blocks. Its values match the output of the
# fourier_diagonalization cell above.
numpy.kron(Jd, id)

array([[1., 0., 1., 0.],
       [0., 1., 0., 1.],
       [1., 0., 1., 0.],
       [0., 1., 0., 1.]])

In [13]:
# 1_d: column vector of ones, shape (d, 1) with d = 2.
_1d = numpy.ones((2, 1))
_1d

array([[1.],
       [1.]])

In [14]:
# (1_d ⊗ I_d)(1_d^T ⊗ I_d) = (1_d 1_d^T) ⊗ (I_d I_d) = J_d ⊗ I_d by the
# mixed-product property of the Kronecker product, so this reproduces J_d ⊗ I_d.
numpy.matmul(numpy.kron(_1d, id), numpy.kron(_1d.T, id))

array([[1., 0., 1., 0.],
       [0., 1., 0., 1.],
       [1., 0., 1., 0.],
       [0., 1., 0., 1.]])

In [None]:
def f(
    y: numpy.ndarray,
    kernel: numpy.ndarray,
    d: int
) -> numpy.ndarray:
    
    nb_rows, nb_cols = y.shape

    h_diag = lasp.utils.fourier_diagonalization(
        kernel = kernel,
        shape_out = numpy.array([nb_rows, nb_cols])
    )

    
    _1_d = numpy.ones(shape=(d, 1))
    _i_d = numpy.identity(n = d)

    