In [2]:
import torch
def pseudoinv_via_newtonschulz5(G, Q_0=None, steps=10, eps=1e-7):
    assert len(G.shape) == 2
    if Q_0 is not None:
        Q = Q_0.bfloat16()
    else:
        Q = G.bfloat16().T
        Q /= Q.norm()**2 + eps  # ensure top singular value <= 1
    G_bf = G.bfloat16()
    if G.size(0) > G.size(1):
        Q = Q.T
    for _ in range(steps):
        Q = 2 * Q - Q @ G_bf @ Q
    if G.size(0) > G.size(1):
        Q = Q.T

    return Q.float()

In [22]:
def get_grad(G, B):
    Q = torch.linalg.pinv(B) @ G
    return Q

def get_grad_my(G, B, num_steps=10):
    Q = pseudoinv_via_newtonschulz5(B, steps=num_steps) @ G
    return Q

In [10]:
G = torch.tensor([[0, 2], [0, 4], [0, 8]], dtype=float)
print(G)
print(G.norm(dim=0, keepdim=True).clamp(min=1e-8))
G /= G.norm(dim=0, keepdim=True)

G

tensor([[0., 2.],
        [0., 4.],
        [0., 8.]], dtype=torch.float64)
tensor([[0.0000, 9.1652]], dtype=torch.float64)


tensor([[   nan, 0.2182],
        [   nan, 0.4364],
        [   nan, 0.8729]], dtype=torch.float64)

In [28]:
m, n = 2, 10
A = torch.diag(torch.randn(m)) # scale
# A = torch.randn(m, m) # affine
# A = A.T @ A
N = 10
for i in range(N):
    G_1 = torch.randn(m, n)
    B_1 = torch.diag(torch.diag(G_1 @ G_1.T))
    Q_1 = get_grad(G_1, B_1)
    Q_my = get_grad_my(G_1, B_1, num_steps=2000)
    G_2 = A @ G_1
    B_2 = torch.diag(torch.diag(G_2 @ G_2.T))
    Q_2 = get_grad(G_2, B_2)
    print(f"Step{i}. Q_1 - A @ Q_2 = {(Q_1 - A @ Q_2).norm()}, Q_1 - Q_my = {(Q_1 - Q_my).norm()}")
#print(f">>> {1 / N * (Q_1_avg - Q_3_avg).norm()} <<<")
#print(f">>> {1 / N * (Q_3_avg - A @ Q_2_avg).norm()} <<<")

Step0. Q_1 - A @ Q_2 = 2.061570469891194e-08, Q_1 - Q_my = 0.00040847985656000674
Step1. Q_1 - A @ Q_2 = 4.5663899328474145e-08, Q_1 - Q_my = 0.001365800853818655
Step2. Q_1 - A @ Q_2 = 5.4567873775113185e-08, Q_1 - Q_my = 0.0017510391771793365
Step3. Q_1 - A @ Q_2 = 2.3302845875150524e-08, Q_1 - Q_my = 0.00036933328374288976
Step4. Q_1 - A @ Q_2 = 6.084324155608556e-08, Q_1 - Q_my = 0.0016209579771384597
Step5. Q_1 - A @ Q_2 = 3.3140562294420306e-08, Q_1 - Q_my = 0.00024168229720089585
Step6. Q_1 - A @ Q_2 = 4.713541201795124e-08, Q_1 - Q_my = 0.0005427215364761651
Step7. Q_1 - A @ Q_2 = 3.9295990461596375e-08, Q_1 - Q_my = 0.000812358281109482
Step8. Q_1 - A @ Q_2 = 6.410077446616924e-08, Q_1 - Q_my = 0.002119492506608367
Step9. Q_1 - A @ Q_2 = 3.3385834541377335e-08, Q_1 - Q_my = 0.0008281836635433137


In [14]:
G_2

tensor([-0.8974,  2.5341])

In [7]:
G_1 @ G_1.T

tensor([[ 2.4193, -0.5006, -2.8286, -1.0936,  1.1911, -2.7360,  1.0512,  1.7063,
         -0.2348, -1.8906],
        [-0.5006,  3.9896, -0.7531, -5.3247, -0.6357,  3.0742, -2.6383,  0.6293,
         -2.9536,  3.2870],
        [-2.8286, -0.7531,  3.7679,  3.1903, -1.2585,  2.3350, -0.3953, -2.3333,
          1.3085,  1.2131],
        [-1.0936, -5.3247,  3.1903,  8.4235,  0.0176, -2.3459,  2.9827, -2.1746,
          4.3946, -3.2818],
        [ 1.1911, -0.6357, -1.2585,  0.0176,  0.6254, -1.5982,  0.7600,  0.7417,
          0.1851, -1.2209],
        [-2.7360,  3.0742,  2.3350, -2.3459, -1.5982,  4.7128, -2.7512, -1.2956,
         -1.6721,  4.0070],
        [ 1.0512, -2.6383, -0.3953,  2.9827,  0.7600, -2.7512,  1.9648,  0.1295,
          1.7682, -2.6254],
        [ 1.7063,  0.6293, -2.3333, -2.1746,  0.7417, -1.2956,  0.1295,  1.4518,
         -0.9245, -0.6014],
        [-0.2348, -2.9536,  1.3085,  4.3946,  0.1851, -1.6721,  1.7682, -0.9245,
          2.3422, -2.0537],
        [-1.8906,  