# Kronecker product
For any matrices $X \in R^{m \times n}$ and $Y \in R^{p \times q}$, the Kronecker product $X \otimes Y$ is the block matrix:

\begin{align*}
X \otimes Y = \begin{bmatrix} 
x_{11}Y & \dots  & x_{1n}Y \\
\vdots & \ddots & \vdots \\
x_{m1}Y & \dots  & x_{mn}Y 
\end{bmatrix} \in R^{mp \times nq},
\end{align*}

where $x_{ij}$ is the element of $X$ at its $i^{\text{th}}$ row and $j^{\text{th}}$ column.
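As a quick sanity check of the definition, the sketch below builds $X \otimes Y$ block by block (block $(i, j)$ is $x_{ij}Y$) for two small integer matrices and compares the result with `torch.kron`. The matrix values are arbitrary illustrative choices, not taken from the notebook.

```python
import torch

# Small integer matrices so the block structure is easy to read.
X = torch.tensor([[1, 2],
                  [3, 4]])
Y = torch.tensor([[0, 1],
                  [1, 0]])

# Build X kron Y directly from the block definition: block (i, j) is x_ij * Y.
blocks = [[X[i, j] * Y for j in range(X.shape[1])] for i in range(X.shape[0])]
manual = torch.cat([torch.cat(row, dim=1) for row in blocks], dim=0)

# torch.kron computes the same (mp x nq) matrix.
builtin = torch.kron(X, Y)
print(builtin.tolist())  # [[0, 1, 0, 2], [1, 0, 2, 0], [0, 3, 0, 4], [3, 0, 4, 0]]
print(torch.equal(manual, builtin))  # True
```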

In [1]:
import torch
e=torch.randn(10,4)
r=torch.randn(10,4)

In [2]:
x=torch.stack([e, r], 1)
x.shape

torch.Size([10, 2, 4])

In [3]:
torch.kron(x,x).shape

torch.Size([100, 4, 16])
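`torch.kron` is not batched: it takes the Kronecker product over every dimension, including the leading one, which is why the result above has shape `(10*10, 2*2, 4*4) = (100, 4, 16)` rather than `(10, 4, 16)`. The sketch below uses `B` and a fresh random `x` as illustrative stand-ins for the tensors above. It shows that slice `i*B + j` of the full result is the product of samples `i` and `j`, so the per-sample products are every `(B + 1)`-th slice. This recovers the batched result, but it computes `B` times more than necessary.

```python
import torch

torch.manual_seed(0)
B = 10
x = torch.randn(B, 2, 4)

# torch.kron applies the Kronecker product to every dimension,
# including the batch dimension, so there are B*B leading slices.
full = torch.kron(x, x)
print(full.shape)  # torch.Size([100, 4, 16])

# Reference batched result: one 2-D Kronecker product per sample.
ref = torch.stack([torch.kron(x[i], x[i]) for i in range(B)])
print(ref.shape)  # torch.Size([10, 4, 16])

# Slice i*B + j of `full` is kron(x[i], x[j]); the per-sample products
# sit at i*B + i, i.e. every (B + 1)-th slice.
print(torch.allclose(full[:: B + 1], ref))  # True
```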

In [4]:
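def kron(a, b):
    """
    Kronecker product of matrices a and b with leading batch dimensions.
    Batch dimensions are broadcast using standard PyTorch broadcasting rules.
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    # Shape of the last two dims of the result: (m*p, n*q).
    siz1 = torch.Size([a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]])
    # (..., m, 1, n, 1) * (..., 1, p, 1, q) -> (..., m, p, n, q)
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    # Keep the broadcast batch shape, merge (m, p) -> m*p and (n, q) -> n*q.
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)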

In [5]:
kron(x,x).shape

torch.Size([10, 4, 16])
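A minimal check, assuming the intended result is one 2-D Kronecker product per sample: `kron` is re-declared here in condensed form (no docstring) so the snippet runs on its own, then compared with a per-sample loop over `torch.kron`. Non-square factors with different shapes are used so that a mix-up of rows and columns would show up. The last lines illustrate the broadcasting mentioned in the docstring: a single un-batched `(3, 3)` matrix is paired with every sample.

```python
import torch

def kron(a, b):
    # Same outer-product-then-reshape idea as the notebook's kron.
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    batch = res.shape[:-4]
    return res.reshape(batch + (a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]))

torch.manual_seed(0)
a = torch.randn(10, 2, 4)
b = torch.randn(10, 3, 5)

out = kron(a, b)
ref = torch.stack([torch.kron(a[i], b[i]) for i in range(10)])
print(out.shape)  # torch.Size([10, 6, 20])
print(torch.allclose(out, ref))  # True

# Broadcasting: an un-batched matrix is combined with every sample of `a`.
c = torch.randn(3, 3)
print(kron(a, c).shape)  # torch.Size([10, 6, 12])
```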

In [6]:
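def kronecker_product_einsum_batched(A: torch.Tensor, B: torch.Tensor):
    """
    Batched version of the Kronecker product.
    :param A: has shape (b, a, c)
    :param B: has shape (b, k, p)
    :return: (b, ak, cp)
    """
    assert A.dim() == 3 and B.dim() == 3

    # Per-batch outer product; 'bakcp' keeps the row indices (a, k) and the
    # column indices (c, p) adjacent, so the reshape merges them in Kronecker order.
    # reshape (rather than view) also works if einsum returns a non-contiguous tensor.
    res = torch.einsum('bac,bkp->bakcp', A, B).reshape(A.size(0),
                                                       A.size(1) * B.size(1),
                                                       A.size(2) * B.size(2))
    return res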

In [7]:
kronecker_product_einsum_batched(x,x).shape

torch.Size([10, 4, 16])
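The einsum subscripts `'bac,bkp->bakcp'` perform a per-batch outer product with no summed index. Ordering the output as `a, k, c, p` is what places each entry at row `a_idx * K + k_idx` and column `c_idx * P + p_idx` after the reshape, where `K` and `P` are the sizes of `B`'s last two dimensions. The sketch below re-declares the function under the hypothetical shorter name `kron_einsum`, checks one entry against the block definition, and compares the whole result with a per-sample `torch.kron` loop. The shapes and indices are illustrative choices.

```python
import torch

def kron_einsum(A, B):
    # Hypothetical name; mirrors kronecker_product_einsum_batched.
    # Output order 'b a k c p': rows (a, k) and columns (c, p) are adjacent,
    # so merging them gives row a*K + k and column c*P + p.
    b, a, c = A.shape
    _, k, p = B.shape
    return torch.einsum('bac,bkp->bakcp', A, B).reshape(b, a * k, c * p)

torch.manual_seed(0)
A = torch.randn(10, 2, 4)
Bm = torch.randn(10, 3, 5)
out = kron_einsum(A, Bm)
print(out.shape)  # torch.Size([10, 6, 20])

# One entry against the definition: block (i, j) of sample s is A[s, i, j] * B[s].
s, i, j, kk, pp = 7, 1, 2, 2, 4
print(torch.allclose(out[s, i * 3 + kk, j * 5 + pp], A[s, i, j] * Bm[s, kk, pp]))  # True

# The whole result against a per-sample loop over torch.kron.
ref = torch.stack([torch.kron(A[n], Bm[n]) for n in range(10)])
print(torch.allclose(out, ref))  # True
```

Unlike `kron`, this version accepts only three-dimensional inputs (enforced by the `assert` in the original function), so it does not broadcast over extra or missing batch dimensions.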