In [206]:
import numpy as np
import torch

In [207]:
# SVD 实现正交初始化
def orthogonal_init(shape, gain=1.):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0., 1., flat_shape)
    u, d, vt = np.linalg.svd(a, full_matrices=False)
    # pick the one with the correct shape
#     print('u' if u.shape == flat_shape else 'v')
    q = u if u.shape == flat_shape else vt
    q = q.reshape(shape)
    return gain*q

In [208]:
shape = (4, 4)

In [209]:
x = orthogonal_init(shape)

In [210]:
# pytorch 正交初始化(PQ 分解)
w = torch.empty(shape)
torch.nn.init.orthogonal_(w)

tensor([[ 0.2355,  0.3814, -0.8635,  0.2311],
        [-0.9007,  0.0673, -0.2985, -0.3085],
        [-0.3626,  0.1880,  0.2212,  0.8856],
        [-0.0432, -0.9026, -0.3410,  0.2592]])

In [211]:
np.linalg.det(x), np.linalg.det(w)

(-1.0000000000000002, 0.99999964)

In [212]:
np.linalg.det(x@x@x@x)

1.0000000000000018

行列式为+-1, 不缩放

In [213]:
# 特征值分解
e_value, e_vector = np.linalg.eig(w)

In [214]:
print(np.allclose(e_vector@np.diag(e_value)@np.linalg.inv(e_vector), w))
e_vector@np.diag(e_value)@np.linalg.inv(e_vector)

True


array([[ 0.23550886+0.j,  0.38135552+0.j, -0.8635328 +0.j,
         0.23111524+0.j],
       [-0.9006835 +0.j,  0.06726437+0.j, -0.29849133+0.j,
        -0.30846   +0.j],
       [-0.36255038+0.j,  0.18804774+0.j,  0.22118758+0.j,
         0.8855909 +0.j],
       [-0.04315106+0.j, -0.9025969 +0.j, -0.34101295+0.j,
         0.25916567+0.j]], dtype=complex64)

In [215]:
np.allclose(w@w, e_vector@(np.diag(e_value)**2)@np.linalg.inv(e_vector))

False

https://smerity.com/articles/2016/orthogonal_init.html

$
\begin{aligned}
F^{2} &=\left(Q \Lambda Q^{-1}\right)\left(Q \Lambda Q^{-1}\right) \\
&=Q \Lambda\left(Q^{-1} Q\right) \Lambda Q^{-1} \\
&=Q \Lambda^{2} Q^{-1}
\end{aligned}
$