# Kronecker product


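1. $W \in \mathbb{R}^{m \times n}$

2. $X \in \mathbb{R}^{m_1 \times n_1}$

3. $Z \in \mathbb{R}^{m_2 \times n_2}$

4. $m_2 = m / m_1$ and $n_2 = n / n_1$, so that $X \otimes Z \in \mathbb{R}^{m \times n}$ has the same shape as $W$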

For any matrices $X \in \mathbb{R}^{m \times n}$ and $Y \in \mathbb{R}^{p \times q}$, the Kronecker product $X \otimes Y$ is the block matrix:

\begin{align*}
X \otimes Y = \begin{bmatrix} 
x_{11}Y & \dots  & x_{1n}Y \\
\vdots & \ddots & \vdots \\
x_{m1}Y & \dots  & x_{mn}Y 
\end{bmatrix} \in \mathbb{R}^{mp \times nq},
\end{align*}

where $x_{ij}$ is the element of $X$ at its $i^{\text{th}}$ row and $j^{\text{th}}$ column.
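To make the block structure concrete, the sketch below builds $X \otimes Y$ literally from the definition, placing the blocks $x_{ij}Y$ side by side with `torch.cat`, and compares the result with `torch.kron`. The helper `kron_by_blocks` is a hypothetical name used only for this illustration.

```python
import torch

def kron_by_blocks(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: X ⊗ Y assembled block by block from the definition."""
    m, n = X.shape
    # Block row i is [x_i1 * Y, ..., x_in * Y], concatenated along the columns.
    block_rows = [torch.cat([X[i, j] * Y for j in range(n)], dim=1) for i in range(m)]
    # Stack the m block rows vertically: the result has shape (m*p, n*q).
    return torch.cat(block_rows, dim=0)

X = torch.randn(3, 2)
Y = torch.randn(4, 5)
K = kron_by_blocks(X, Y)
print(K.shape)                              # torch.Size([12, 10])
print(torch.allclose(K, torch.kron(X, Y)))  # True
```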

In [16]:
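import torch

m, n = 6, 4                    # shape of the full matrix W
m1, n1 = 3, 2                  # shape of the first factor A (X above)
m2, n2 = m // m1, n // n1      # shape of the second factor B (Z above), here 2 x 2

A = torch.randn(m1, n1)
B = torch.randn(m2, n2)
W = torch.kron(A, B)           # W = A ⊗ B has shape (m1*m2, n1*n2) = (m, n)
W.shape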

torch.Size([6, 4])

# Implicit Linear Transformation via KP
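The implicit product evaluates $Wx$ for $W = A \otimes B$ without ever forming $W$. It rests on the standard vec identity, stated below for $A \in \mathbb{R}^{m_1 \times n_1}$, $B \in \mathbb{R}^{m_2 \times n_2}$ and $X \in \mathbb{R}^{n_2 \times n_1}$, where $\operatorname{vec}(\cdot)$ stacks columns (column-major order). The second line is the row-major counterpart written entry by entry for PyTorch's default `reshape`, with 0-based indices and $\tilde{X} \in \mathbb{R}^{n_1 \times n_2}$ defined by $\tilde{X}_{jl} = x_{j n_2 + l}$:

\begin{align*}
(A \otimes B)\,\operatorname{vec}(X) &= \operatorname{vec}\!\left(B X A^{\top}\right), \\
\big((A \otimes B)\,x\big)_{i m_2 + k} &= \sum_{j=0}^{n_1-1} \sum_{l=0}^{n_2-1} a_{ij}\, b_{kl}\, x_{j n_2 + l} = \big(A \tilde{X} B^{\top}\big)_{ik}.
\end{align*}

Evaluating $B X A^{\top}$ costs $m_2 n_2 n_1 + m_2 n_1 m_1$ multiply-adds and $A \tilde{X} B^{\top}$ costs $m_1 n_1 n_2 + m_1 n_2 m_2$; both are typically far below the $m_1 m_2 n_1 n_2$ needed for the dense product $Wx$.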

In [51]:
def V(x):
    # vec: stack the columns of a matrix (column-major order)
    return x.T.reshape(-1)

def R(x, n2, n1):
    # inverse of V: fill an n2 x n1 matrix column by column from a vector
    return x.reshape(n1, n2).T

def implicit_kp_linear(A, B, x):
    # (A ⊗ B) vec(X) = vec(B X A^T), so W @ x is computed without forming W
    _, n2 = B.shape
    _, n1 = A.shape
    return V(B @ R(x, n2, n1) @ A.T)


In [53]:
x = torch.randn(n)
print(x, x.dtype)

print(W @ x)

# should match W @ x up to floating-point rounding
print(implicit_kp_linear(A, B, x))


In [55]:
x = torch.ones(n) * 3.1
print(x, x.dtype)

print(W @ x)

print(implicit_kp_linear(A, B, x))
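As a cross-check, the self-contained sketch below compares `W @ x` with both index conventions: the column-major identity $(A \otimes B)\operatorname{vec}(X) = \operatorname{vec}(B X A^{\top})$ and the row-major form $A \tilde{X} B^{\top}$, which matches PyTorch's default `reshape`. The seed and the factor sizes are arbitrary; the sizes are deliberately all different so that an index mix-up would not cancel out.

```python
import torch

torch.manual_seed(0)
m1, n1, m2, n2 = 3, 2, 4, 5           # distinct sizes expose transposition mistakes
A = torch.randn(m1, n1)
B = torch.randn(m2, n2)
W = torch.kron(A, B)                  # explicit (m1*m2) x (n1*n2) matrix, for checking only
x = torch.randn(n1 * n2)

# Column-major form: vec(B X A^T), where vec stacks columns.
X_col = x.reshape(n1, n2).T           # X_col[l, j] = x[j*n2 + l], shape (n2, n1)
y_col = (B @ X_col @ A.T).T.reshape(-1)

# Row-major form: flatten(A X B^T), using PyTorch's default reshape/flatten.
y_row = (A @ x.reshape(n1, n2) @ B.T).reshape(-1)

print(torch.allclose(W @ x, y_col, atol=1e-5))  # True
print(torch.allclose(W @ x, y_row, atol=1e-5))  # True
```

The row-major form needs no transposes, which makes it the more natural choice when the rest of the code uses PyTorch's default memory layout.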


In [60]:
torch.matmul(torch.randn(1024, 2, 18), 
             torch.randn(18, 18)).size()

torch.Size([1024, 2, 18])
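The cell above shows `torch.matmul` broadcasting a plain matrix across a stack of matrices. The same broadcasting works with the plain matrix on the left, which is what lets the factors $A$ and $B$ act on a whole batch of reshaped inputs without copying them per sample. A short shape check, with arbitrary sizes:

```python
import torch

torch.manual_seed(0)
b, m2, n2, n1, m1 = 5, 2, 2, 3, 4
B = torch.randn(m2, n2)
X = torch.randn(b, n2, n1)        # a batch of b matrices
A = torch.randn(m1, n1)

left = B @ X                      # (m2, n2) @ (b, n2, n1) -> (b, m2, n1)
both = B @ X @ A.T                # (b, m2, n1) @ (n1, m1) -> (b, m2, m1)
print(left.shape, both.shape)     # torch.Size([5, 2, 3]) torch.Size([5, 2, 4])

# Same result as applying the factors sample by sample.
loop = torch.stack([B @ X[i] @ A.T for i in range(b)])
print(torch.allclose(both, loop, atol=1e-5))  # True
```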

In [None]:
x = torch.randn(10, n)  # batch of 10 row vectors, each of length n = n1*n2
x.shape,W.shape

In [None]:
torch.matmul(x,W.T)

In [None]:
B.T.shape

In [None]:
# implicit x @ W.T: reshape each row to n1 x n2 (row-major), then compute A X B^T
(A @ x.reshape(len(x), n1, n2) @ B.T).flatten(1)

# Batch-wise Implicit Linear Transformation via KP

In [None]:
x = torch.ones(10, n) * 3.1  # batch of 10 identical rows, each equal to the constant vector used earlier
x

In [None]:
x.shape

In [None]:
W.shape

In [None]:
torch.mm(x,W.T)

In [None]:
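# same result as torch.mm(x, W.T) via the column-major identity vec(B X A^T), X of shape n2 x n1
(B @ x.reshape(len(x), n1, n2).transpose(1, 2) @ A.T).transpose(1, 2).flatten(1)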
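For reference, the batched computation can be packaged as a function. This is a sketch using the row-major form $A \tilde{X} B^{\top}$, which matches PyTorch's default `reshape` and so needs no transposes; `implicit_kp_linear_batched` is a hypothetical name and the factor shapes are arbitrary.

```python
import torch

def implicit_kp_linear_batched(A: torch.Tensor, B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x @ kron(A, B).T for x of shape (batch, n1*n2),
    computed without materialising kron(A, B)."""
    n1 = A.shape[1]
    n2 = B.shape[1]
    X = x.reshape(x.shape[0], n1, n2)   # row-major: X[b, j, l] = x[b, j*n2 + l]
    Y = A @ X @ B.T                     # (batch, m1, m2) via broadcasting matmul
    return Y.flatten(1)                 # (batch, m1*m2), same row order as kron(A, B)

torch.manual_seed(0)
A = torch.randn(3, 2)
B = torch.randn(4, 5)
W = torch.kron(A, B)                    # (12, 10), built only for the check
x = torch.randn(10, 10)

out = implicit_kp_linear_batched(A, B, x)
print(out.shape)                                 # torch.Size([10, 12])
print(torch.allclose(out, x @ W.T, atol=1e-5))   # True
```

Per row this costs $m_1 n_1 n_2 + m_1 n_2 m_2$ multiply-adds instead of $m_1 m_2 n_1 n_2$ for the dense product.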

In [None]:
B.shape

In [None]:
def kron(a, b):
    """
    Kronecker product of matrices a and b with leading batch dimensions.
    Leading batch dimensions are broadcast following the usual PyTorch rules.
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    # shape of each output matrix: (rows_a * rows_b, cols_a * cols_b)
    siz1 = torch.Size([a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]])
    # a -> (..., M, 1, N, 1), b -> (..., 1, P, 1, Q); the product is (..., M, P, N, Q)
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)
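The unsqueeze trick in `kron` places the factors on interleaved axes, so that `res[..., i, k, j, l] = a[..., i, j] * b[..., k, l]`; merging the axis pairs $(i, k)$ and $(j, l)$ then yields the Kronecker layout. The self-contained sketch below repeats that construction under a hypothetical name, using `None` indexing in place of `unsqueeze`, and checks it against `torch.kron` applied sample by sample.

```python
import torch

def kron_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of the broadcasting kron above."""
    # a -> (..., M, 1, N, 1), b -> (..., 1, P, 1, Q); the product is (..., M, P, N, Q)
    res = a[..., :, None, :, None] * b[..., None, :, None, :]
    batch = res.shape[:-4]
    M, P, N, Q = res.shape[-4:]
    return res.reshape(*batch, M * P, N * Q)

torch.manual_seed(0)
a = torch.randn(4, 3, 2)     # batch of 4 matrices, each 3 x 2
b = torch.randn(4, 2, 5)     # batch of 4 matrices, each 2 x 5
out = kron_broadcast(a, b)
ref = torch.stack([torch.kron(a[i], b[i]) for i in range(4)])
print(out.shape)                  # torch.Size([4, 6, 10])
print(torch.allclose(out, ref))   # True
```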

In [None]:
kron(x,x).shape

In [None]:
def kronecker_product_einsum_batched(A: torch.Tensor, B: torch.Tensor):
    """
    Batched Version of Kronecker Products
    :param A: has shape (b, a, c)
    :param B: has shape (b, k, p)
    :return: (b, ak, cp)
    """
    assert A.dim() == 3 and B.dim() == 3

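    # res[b, a, k, c, p] = A[b, a, c] * B[b, k, p]; merging (a, k) and (c, p) gives the Kronecker layout.
    # reshape copies if needed, so it works even when einsum returns a non-contiguous tensor.
    res = torch.einsum('bac,bkp->bakcp', A, B).reshape(A.size(0),
                                                       A.size(1) * B.size(1),
                                                       A.size(2) * B.size(2))
    return res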
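`torch.einsum` can also apply $A \otimes B$ to a batch without building the Kronecker product at all: the batched implicit transformation from the previous section is a single contraction over the two input indices. A sketch with arbitrary shapes; the subscript letters are one possible labelling.

```python
import torch

torch.manual_seed(0)
m1, n1, m2, n2 = 3, 2, 4, 5
A = torch.randn(m1, n1)
B = torch.randn(m2, n2)
x = torch.randn(10, n1 * n2)              # batch of 10 input vectors

# y[b, i, k] = sum_{j, l} A[i, j] * B[k, l] * x[b, j*n2 + l]
X = x.reshape(-1, n1, n2)
y = torch.einsum('ij,kl,bjl->bik', A, B, X).reshape(-1, m1 * m2)

print(torch.allclose(y, x @ torch.kron(A, B).T, atol=1e-5))   # True
```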

In [None]:
kronecker_product_einsum_batched(x.unsqueeze(0), x.unsqueeze(0)).shape  # expects 3-D inputs, so add a leading batch dimension of size 1