In [95]:
import torch
import torch.linalg as L
from torch.autograd.functional import jvp

import numpy as np

## 1
$ f(x) = \|A - xx^\top\|_F^2, \quad A = A^\top \in \mathbb{R}^{n \times n}, \quad x \in \mathbb{R}^n$ 

In [96]:
# Seed every RNG this notebook draws from: torch for the autograd checks and
# NumPy for the later Kronecker / einsum cells, so a fresh "Run All"
# regenerates the same data.
torch.manual_seed(0)
np.random.seed(0)

# Random test point and matrix. A is not symmetrized here; the closed-form
# gradient checked in the next cell is written with A + A.T.
x = torch.randn(5, requires_grad=True)
A = torch.randn(5, 5)

# f(x) = ||A - x x^T||_F^2 (linalg.norm of a matrix defaults to the Frobenius norm).
y = L.norm(A - torch.outer(x, x))**2

# Reverse-mode autodiff: stores df/dx in x.grad.
y.backward()

In [97]:
# Closed-form gradient of f(x) = ||A - x x^T||_F^2 for a general square A:
#   grad f = -2 (A + A^T - 2 x x^T) x   (equals -4 (A - x x^T) x when A = A^T).
x_grad = -2 * (A + A.T - 2 * torch.outer(x, x)) @ x
# Compare with the gradient autograd stored in x.grad.
torch.allclose(x_grad, x.grad)

True

## 2
$f(X) = \text{Tr} \left( (X^\top X)^{-1} X^\top A X \right), \quad X \in \mathbb{R}^{n \times p}, \quad \text{rank}(X) = p, \quad A = A^\top$

In [98]:
# Problem sizes: X is n x p, A is n x n.
n, p = 4, 3

# Draw order (X, A, H) is kept so the random stream is consumed identically.
X = torch.randn((n, p)).requires_grad_(True)
A = torch.randn((n, n))
H = torch.randn_like(X)  # perturbation direction for the JVP checks

In [99]:
# Differential of the Gram matrix in direction H: d(X^T X)[H] = H^T X + X^T H.
# jvp returns (f(X), J_f(X) applied to H); index [1] is the directional derivative.
torch.allclose(
    jvp(func=lambda X: X.T @ X, inputs=X,v=H)[1],
    H.T @ X + X.T @ H
)

True

In [100]:
# Product rule with constant A: d(X^T A X)[H] = H^T A X + X^T A H.
torch.allclose(
    jvp(func=lambda X: X.T @ A @ X, inputs=X, v=H)[1],
    H.T @ A @ X + X.T @ A @ H
)

True

In [101]:
# Differential of a matrix inverse, with G = X^T X and dG[H] = H^T X + X^T H:
#   d(G^{-1})[H] = -G^{-1} (dG[H]) G^{-1}.
torch.allclose(
    jvp(func=lambda X: L.inv(X.T @ X), inputs=X, v=H)[1],
    -L.inv(X.T @ X) @ (H.T @ X + X.T @ H) @ L.inv(X.T @ X)
)

True

In [102]:
# Product rule on G^{-1} (X^T A X), with G = X^T X:
#   d[...] = d(G^{-1})[H] (X^T A X) + G^{-1} d(X^T A X)[H],
# reusing the two differentials verified in the previous cells.
torch.allclose(
    jvp(func=lambda X: L.inv(X.T @ X) @ X.T @ A @ X, inputs=X, v=H)[1],
    -L.inv(X.T @ X) @ (H.T @ X + X.T @ H) @ L.inv(X.T @ X) @ X.T @ A @ X + L.inv(X.T @ X) @ (H.T @ A @ X + X.T @ A @ H)
)

True

In [103]:
# Directional derivative of f(X) = Tr(G^{-1} X^T A X), G = X^T X, written as an
# inner product df[H] = Tr(grad^T H) with
#   grad = -(X G^{-1} X^T (A + A^T) X G^{-1}) + (A + A^T) X G^{-1}.
torch.allclose(
    jvp(func=lambda X: torch.trace(L.inv(X.T @ X) @ X.T @ A @ X), inputs=X, v=H)[1],
    torch.trace((-(X @ L.inv(X.T @ X) @ X.T @ (A + A.T) @ X @ L.inv(X.T @ X)) + (A + A.T) @ X @ L.inv(X.T @ X)).T @ H)
)

True

In [104]:
# f(X) = Tr((X^T X)^{-1} X^T A X); backward() stores df/dX in X.grad.
y = torch.trace(L.inv(X.T @ X) @ X.T @ A @ X)
y.backward()

# Closed-form gradient: -(X G^{-1} X^T (A + A^T) X G^{-1}) + (A + A^T) X G^{-1},
# with G = X^T X. The repeated factors are named once; the left-to-right
# association of each matrix chain is unchanged.
gram_inv = L.inv(X.T @ X)
A_plus_At = A + A.T
torch.allclose(
    X.grad,
    -(X @ gram_inv @ X.T @ A_plus_At @ X @ gram_inv) + A_plus_At @ X @ gram_inv
)

True

## 3
$f(X) = \text{Tr}(X \odot X), \quad X \in \mathbb{R}^{m \times n}$

In [105]:
# Square 3 x 3 test matrix and a perturbation direction of the same shape.
m = n = 3
X = torch.randn((m, n)).requires_grad_(True)
H = torch.randn_like(X)

In [106]:
# Only diagonal entries enter Tr(X ⊙ X) = sum_i x_ii^2, hence
#   df[H] = 2 sum_i x_ii h_ii = 2 Tr(Diag(X)^T Diag(H)),
# where Diag(.) (torch.diag(M.diag())) keeps the diagonal and zeroes the rest.
torch.allclose(
    jvp(func=lambda X:torch.trace(X * X), inputs=X, v=H)[1],
    2 * torch.trace(torch.diag(X.diag()).T @ torch.diag(H.diag()))
)

True

In [107]:
# Reverse-mode gradient of f(X) = Tr(X ⊙ X) = sum_i x_ii^2.
y = (X * X).trace()
y.backward()

# Expected gradient: 2 * Diag(X); off-diagonal entries of X do not affect f.
torch.allclose(X.grad, 2 * torch.diag(X.diag()))

True

## 4

$f(X) = \text{Tr}(X \operatorname{diag}(X)), \quad X \in \mathbb{R}^{n \times n}, \quad \operatorname{diag}(X) = \operatorname{diag}(x_{11}, \dots, x_{nn})$

In [108]:
m = n = 3
X = torch.randn((m, n)).requires_grad_(True)
H = torch.randn_like(X)

# torch.diag(X) is the 1-D diagonal of X; broadcasting it against X scales
# column j by x_jj, so this trace equals Tr(X @ Diag(X)) = sum_i x_ii**2.
y = (X * torch.diag(X)).trace()
y.backward()

X.grad

tensor([[-1.1149,  0.0000,  0.0000],
        [ 0.0000, -1.4213,  0.0000],
        [ 0.0000,  0.0000,  0.7594]])

In [109]:
# Product rule on X * diag(X) (elementwise; the 1-D diagonal is broadcast
# against every row): df[H] = Tr(H * diag(X) + X * diag(H)).
torch.allclose(
    jvp(func=lambda X: torch.trace(X * torch.diag(X)), inputs=X, v=H)[1],
    torch.trace(H * torch.diag(X) + X * torch.diag(H))
)

True

In [110]:
# X.grad holds the result of y.backward() from the cell that defined X.
# Expected gradient of sum_i x_ii^2: the diagonal matrix 2 * Diag(X).
torch.allclose(
    X.grad,
    2 * torch.diag(X.diag())
)

True

In [111]:
# Show the gradient next to the 1-D diagonal of X for visual comparison
# (the check above asserts X.grad == 2 * Diag(X)).
X.grad, torch.diag(X)

(tensor([[-1.1149,  0.0000,  0.0000],
         [ 0.0000, -1.4213,  0.0000],
         [ 0.0000,  0.0000,  0.7594]]),
 tensor([-0.5575, -0.7107,  0.3797], grad_fn=<DiagonalBackward0_copy>))

## 5

$ f(X) = a^\top X^2 b, \quad a, b \in \mathbb{R}^n, \quad X \in \mathbb{R}^{n \times n} $

In [112]:
n = 3
X = torch.randn((n, n)).requires_grad_(True)
a = torch.randn(n)
b = torch.randn_like(a)
H = torch.randn_like(X)

# f(X) = a^T X^2 b, evaluated as one contraction over the matrix product X @ X.
y = torch.einsum('i,ij,j->', a, torch.matmul(X, X), b)
y.backward()

# Closed-form gradient: a b^T X^T + X^T a b^T.
ab_outer = torch.outer(a, b)
torch.allclose(X.grad, ab_outer @ X.T + X.T @ ab_outer)

True

## 6

$ f(X) = \text{Tr}(I \otimes X + X \otimes I), \quad X \in \mathbb{R}^{n \times n}$ 

In [113]:
n = 3
X = torch.randn((n, n)).requires_grad_(True)
I = torch.eye(n)
H = torch.randn_like(X)

# f(X) = Tr(I ⊗ X + X ⊗ I); each Kronecker term has trace n * Tr(X).
y = (torch.kron(I, X) + torch.kron(X, I)).trace()
y.backward()

# Hence f(X) = 2 n Tr(X) and the gradient is 2 n I.
torch.allclose(X.grad, 2 * n * I)

True

## 7

$f(U) = F(W + UV^\top), \quad g(V) = F(W + UV^\top), \quad W \in \mathbb{R}^{m \times n}, \quad U \in \mathbb{R}^{m \times r}, \quad V \in \mathbb{R}^{n \times r}$

In [114]:
m, n, r = 4, 3, 2

# W is a fixed m x n matrix; U (m x r) and V (n x r) are the differentiable factors.
W = torch.randn(m, n)
U = torch.randn(m, r, requires_grad=True)
V = torch.randn(n, r, requires_grad=True)

H = torch.randn_like(U)


def F(X):
    """Scalar test function: trace of 2 * X**2 + cos(X), both taken elementwise."""
    return torch.trace(2 * X**2 + torch.cos(X))


X = W + U @ V.T
X.retain_grad()  # X is an intermediate tensor; keep dF/dX available as X.grad
y = F(X)
y.backward()

In [115]:
# Chain rule through X = W + U V^T: with G = dF/dX (kept in X.grad by
# retain_grad), the gradient with respect to U is G V.
torch.allclose(
    U.grad,
    X.grad @ V 
)

True

In [116]:
# Chain rule through X = W + U V^T: with G = dF/dX (X.grad), the gradient
# with respect to V is G^T U.
torch.allclose(
    V.grad,
    X.grad.T @ U
)

True

In [117]:
m1, m2, n1, n2 = 4, 3, 3, 2

# A[r, c] = (r + 1) + (c + 1): an (m1*m2) x (n1*n2) table of 1-based row + column indices.
A = torch.arange(1, m1 * m2 + 1)[:, None] + torch.arange(1, n1 * n2 + 1)[None, :]

# First Kronecker term: ones ⊗ (within-block row + column index, 1-based).
B1 = torch.ones(m1, n1, dtype=torch.int64)
C1 = torch.arange(1, m2 + 1)[:, None] + torch.arange(1, n2 + 1)[None, :]

# Second Kronecker term: (block offset m2*i + n2*j) ⊗ ones.
B2 = m2 * torch.arange(m1)[:, None] + n2 * torch.arange(n1)[None, :]
C2 = torch.ones(m2, n2, dtype=torch.int64)

# A is exactly a sum of two Kronecker products.
(torch.kron(B1, C1) + torch.kron(B2, C2) == A).all()

tensor(True)

In [118]:
n = 25
# Random 3-way tensor and square matrix, stored in Fortran (column-major) order.
B = np.array(np.random.randn(n, n, n), order='F')
U = np.array(np.random.randn(n, n), order='F')
# Dense route: materialize the n**3 x n**3 matrix (U ⊗ U ⊗ U)^T and apply it
# to the column-major vectorization of B.
y1 = (np.kron(np.kron(U, U), U).T @ B.flatten('F'))
# Tensor route: contract each mode of B with U^T without forming the Kronecker product.
y2 = np.einsum('abc,ia,jb,kc->ijk', B, U.T, U.T, U.T, optimize='optimal')

In [119]:
# The dense result is a flat vector with n**3 entries (n = 25).
y1.shape

(15625,)

In [120]:
# Both routes agree once the einsum result is flattened column-major, like B was.
np.allclose(y1, y2.flatten("F"))

True

In [None]:
# %%timeit
# (np.kron(np.kron(U, U), U).T @ B.flatten('F'))

2.04 s ± 179 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# %%timeit
# np.einsum('abc,ai,bj,ck->ijk', B, U, U, U, optimize='optimal').flatten("F")

683 μs ± 58.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [123]:
n = 4
X = np.random.randn(n, n, n)
A = np.random.randn(n, n)
A = A + A.T  # symmetrize: the solver below diagonalizes A with eigh
I = np.eye(n)

# Right-hand side B = X x_1 A + X x_2 A + X x_3 A (A applied along each mode in
# turn). Contracting the other two modes with the identity is a no-op, so each
# term is written as a single-mode contraction instead of a 4-operand einsum.
B = (
    np.einsum('abc,ia->ibc', X, A)
    + np.einsum('abc,jb->ajc', X, A)
    + np.einsum('abc,kc->abk', X, A)
)

In [124]:
d = np.random.randn(n)
C = np.random.randn(n, n)
# Diagonal scaling without forming diag(d): C @ diag(d) scales the columns of C
# (broadcast C * d), and diag(d) @ C scales its rows ((C.T * d).T).
np.allclose(C @ np.diag(d), C * d), np.allclose(np.diag(d) @ C, (C.T * d[None, :]).T)

(True, True)

In [125]:
# Symmetric eigendecomposition A = U diag(S) U^T (eigh returns eigenvalues, then eigenvectors).
S, U = np.linalg.eigh(A)

In [126]:
# Rotate the right-hand side into the eigenbasis of A: X1 = B x_1 U^T x_2 U^T x_3 U^T.
X1 = np.einsum('abc,ia,jb,kc->ijk', B, U.T, U.T, U.T, optimize='optimal')

# In that basis the operator is diagonal with entries S[i] + S[j] + S[k].
# Build them by broadcasting (n**3 numbers) rather than reading the diagonal
# of three dense n**3 x n**3 Kronecker products.
# NOTE: a zero sum S[i] + S[j] + S[k] would make the system singular.
eig_sums = S[:, None, None] + S[None, :, None] + S[None, None, :]
X2 = X1 / eig_sums

# Rotate the solution back: X = X2 x_1 U x_2 U x_3 U.
X_found = np.einsum('abc,ia,jb,kc->ijk', X2, U, U, U, optimize='optimal')

In [127]:
# The eigenbasis solve recovers the tensor X that was used to build B.
np.allclose(X, X_found)

True

In [128]:
n = 100
U = np.random.randn(n, n)
B = np.random.randn(n, n)

# U B U^T as a two-mode contraction: out[i, j] = sum_ab B[a, b] U[i, a] U[j, b].
np.allclose(
    np.matmul(np.matmul(U, B), U.T),
    np.einsum('ab,ia,jb->ij', B, U, U),
)

True

In [129]:
# %%timeit
# U @ B @ U.T

In [130]:
# %%timeit
# np.einsum('ab,ia,jb->ij', B, U, U, optimize='optimal')

In [131]:
# Block sizes: B is m1 x n1, C is m2 x n2, A has the shape of B ⊗ C.
m1, n1, m2, n2 = 3, 2, 2, 2

A = np.random.standard_normal((m1 * m2, n1 * n2))
B = np.random.standard_normal((m1, n1))
C = np.random.standard_normal((m2, n2))

In [132]:
import itertools

def perfect_shuffle(p, r):
    """Return the perfect-shuffle index permutation of ``range(p * r)``.

    The indices are read with stride ``r``: first ``0, r, 2r, ...``, then
    ``1, r + 1, ...`` and so on. For positive ``p`` and ``r`` this is the
    order obtained by transposing a row-major ``p x r`` grid of indices.

    Parameters
    ----------
    p, r : int
        Grid dimensions.

    Returns
    -------
    numpy.ndarray
        1-D index array. Its dtype is always integral (``np.intp``), even
        when the result is empty, so it is always usable for fancy indexing.
    """
    # One strided run per starting offset 0 .. r - 1.
    strided_runs = (range(start, p * r, r) for start in range(r))
    # An explicit integer dtype keeps the empty case from defaulting to
    # float64, which NumPy rejects as an index array.
    return np.fromiter(itertools.chain.from_iterable(strided_runs), dtype=np.intp)

$\mathcal{P}_{m1, m2} (B \otimes C) \mathcal{P}_{n1, n2}^\top = (C \otimes B)$

In [133]:
# Permuting the rows of B ⊗ C with perfect_shuffle(m1, m2) and its columns with
# perfect_shuffle(n1, n2) yields C ⊗ B; entries are compared for exact equality.
np.all(
        np.kron(B, C)[perfect_shuffle(m1, m2), :][:, perfect_shuffle(n1, n2)] == 
        np.kron(C, B)
    )

True

In [134]:
# (m1*m2) x (n1*n2) matrix of string labels "<row><col>" (1-based); presumably
# used so that element movement under the reshapes below is easy to read off.
M = np.array([[f"{i}{j}" for j in range(1, n1 * n2 + 1)] for i in range(1, m1 * m2 + 1)])
M

array([['11', '12', '13', '14'],
       ['21', '22', '23', '24'],
       ['31', '32', '33', '34'],
       ['41', '42', '43', '44'],
       ['51', '52', '53', '54'],
       ['61', '62', '63', '64']], dtype='<U2')

In [135]:
# First construction of an (m1*n1) x (m2*n2) rearrangement of M: C-order
# reshapes with axis swaps, finished by a column-major reshape.
# NOTE(review): presumably a Van Loan-Pitsianis style rearrangement for
# Kronecker-product approximation — confirm the intended block ordering.
RM_1 = M.reshape(m1, n1, m2 * n2, order='C').transpose(0, 2, 1).reshape(m1, m2, n1, n2, order='C').transpose(0, 1, 3, 2).reshape(m1 * n1, m2 * n2, order='F')

In [136]:
# Second construction of the same (m1*n1) x (m2*n2) shape, using only
# column-major (order='F') reshapes and axis permutations; compared with RM_1 below.
RM_2 = M.reshape(m2, m1, n2, n1, order='F').transpose(0, 2, 1, 3).reshape(m2 * n2, m1, n1, order='F').\
transpose(1, 2, 0).reshape(m1 * n1, m2 * n2, order='F')

In [137]:
# Both reshape/transposition chains place every label of M in the same position.
np.all(RM_1 == RM_2)

True