# Power method for SVD 

In [1]:
import torch

In [90]:
# Draw 10 random (row, col) index pairs inside a 7x5 grid; pairs may repeat.
Ai = torch.stack([torch.randint(7,(10,)), torch.randint(5,(10,))])
# One random value per index pair.
Av =  torch.randn(10)
# Assemble a 7x5 sparse COO tensor from the index pairs and values.
A = torch.sparse_coo_tensor(Ai, Av, (7, 5))
A

tensor(indices=tensor([[0, 3, 6, 2, 4, 3, 0, 6, 3, 0],
                       [0, 2, 4, 1, 0, 2, 0, 2, 4, 0]]),
       values=tensor([ 0.3126, -0.9706,  0.6253,  0.1081,  0.1108, -0.7105,
                       0.1121,  0.1582, -2.2736, -0.9067]),
       size=(7, 5), nnz=10, layout=torch.sparse_coo)

In [91]:
A.is_coalesced()

False

In [92]:
# Merge duplicate index entries (their values are summed) and sort the indices,
# so every stored (row, col) position appears exactly once.
A = A.coalesce()
A

tensor(indices=tensor([[0, 2, 3, 3, 4, 6, 6],
                       [0, 1, 2, 4, 0, 2, 4]]),
       values=tensor([-0.4820,  0.1081, -1.6811, -2.2736,  0.1108,  0.1582,
                       0.6253]),
       size=(7, 5), nnz=7, layout=torch.sparse_coo)

In [93]:
A_dense = A.to_dense()
A_dense

tensor([[-0.4820,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.1081,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -1.6811,  0.0000, -2.2736],
        [ 0.1108,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1582,  0.0000,  0.6253]])

In [94]:
# Reference result: approximate SVD from torch.svd_lowrank with q=2, used below
# to check the hand-rolled power method.
U,S,V = torch.svd_lowrank(A, 2)
S

tensor([2.8904, 0.4945])

In [110]:
V

tensor([[-1.1337e-07,  9.9967e-01],
        [ 1.1309e-08,  4.8412e-05],
        [ 5.8030e-01,  2.0810e-02],
        [ 0.0000e+00,  0.0000e+00],
        [ 8.1440e-01, -1.4828e-02]])

In [111]:
U

tensor([[-1.7660e-08, -9.7450e-01],
        [ 0.0000e+00,  0.0000e+00],
        [ 4.2307e-10,  1.0586e-05],
        [-9.7814e-01, -2.5706e-03],
        [-4.3470e-09,  2.2406e-01],
        [ 0.0000e+00,  0.0000e+00],
        [ 2.0794e-01, -1.2092e-02]])

In [95]:
# Reconstruction error of the 2-component approximation U diag(S) V^T against dense A.
torch.dist(A_dense, U @ torch.diag_embed(S) @ V.t())

tensor(0.2628)

In [117]:
# Same decomposition with q=5 (the number of columns of A): the reconstruction
# error should drop to round-off level.
U_full,S_full,V_full = torch.svd_lowrank(A, 5)
print(S_full)
torch.dist(A_dense, U_full @ torch.diag_embed(S_full) @ V_full.t())

tensor([2.8904, 0.4946, 0.2392, 0.1081, 0.0000])


tensor(9.2431e-07)

In [80]:
# NOTE(review): the captured output below (size=(7, 4), execution count 80) does
# not match the 7x5 matrix built above; it appears to be left over from an
# earlier run with a different A -- re-execute this cell to refresh it.
A.t()

tensor(indices=tensor([[0, 3, 5, 0, 3, 4, 3, 4],
                       [0, 1, 1, 2, 2, 2, 3, 3]]),
       values=tensor([ 0.2899, -0.1359,  0.8700,  0.7641, -1.0162,  0.3804,
                      -1.2005,  0.0538]),
       size=(7, 4), nnz=8, layout=torch.sparse_coo)

In [98]:
# Gram matrix A^T A (5x5). Its eigenvectors are the right singular vectors of A
# and its eigenvalues are the squared singular values, which is what the power
# iteration below relies on.
ATA = torch.sparse.mm(A.t(), A)
ATA

tensor(indices=tensor([[0, 1, 2, 2, 4, 4],
                       [0, 1, 4, 2, 4, 2]]),
       values=tensor([0.2446, 0.0117, 3.9212, 2.8513, 5.5602, 3.9212]),
       size=(5, 5), nnz=6, layout=torch.sparse_coo)

In [121]:
# First column of V, i.e. the reference right singular vector paired with S[0].
v0 = V.t()[0]
v0

tensor([-1.1337e-07,  1.1309e-08,  5.8030e-01,  0.0000e+00,  8.1440e-01])

In [113]:
ATA @ v0

tensor([-2.7735e-08,  1.3223e-10,  4.8480e+00,  0.0000e+00,  6.8037e+00])

In [112]:
# Check the eigen-relation (A^T A) v0 = S[0]^2 * v0; the distance should be ~0.
torch.dist(ATA @ v0, S[0].pow(2) * v0)

tensor(1.2248e-05)

In [123]:
# Random unit-length start vector for the power iteration (length 5 = columns of A).
v = torch.randn(5)
v /= v.norm()
v

tensor([ 0.2707,  0.4785, -0.5180,  0.6502,  0.0818])

In [154]:
# One power-iteration step: apply A^T A, then renormalise to unit length.
# NOTE(review): the execution counter jumps from 123 to 154, so this cell was
# presumably re-run many times before the output below was captured.
v = ATA @ v
v /= v.norm()
v

tensor([ 1.5700e-40,  0.0000e+00, -5.8030e-01,  0.0000e+00, -8.1440e-01])

In [158]:
# Compare against -v0: a singular vector is only defined up to sign, and this
# run converged to the negated reference vector.
torch.dist(v,-v0)

tensor(1.4474e-06)

In [159]:
vp = ATA @ v
vp

tensor([ 3.8410e-41,  0.0000e+00, -4.8480e+00,  0.0000e+00, -6.8037e+00])

In [164]:
# For a unit eigenvector v, vp = (A^T A) v = s^2 * v, so ||vp|| = s^2 and the
# singular value is its square root.
s = torch.sqrt(vp.norm())
s

tensor(2.8904)

In [167]:
torch.dist(S[0],s)

tensor(2.3842e-07)

In [169]:
# Matching left singular vector, from A v = s * u.
u = (A @ v) / s
u

tensor([-2.6183e-41,  0.0000e+00,  0.0000e+00,  9.7814e-01,  6.0200e-42,
         0.0000e+00, -2.0794e-01])

In [173]:
u0 = U.t()[0]
torch.dist(-u0, u)

tensor(1.3452e-07)

In [178]:
uOv = torch.outer(u,v)
uOv

tensor([[-0.0000e+00, -0.0000e+00,  1.5194e-41, -0.0000e+00,  2.1324e-41],
        [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [ 1.5357e-40,  0.0000e+00, -5.6762e-01,  0.0000e+00, -7.9660e-01],
        [ 0.0000e+00,  0.0000e+00, -3.4934e-42,  0.0000e+00, -4.9031e-42],
        [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [-3.2646e-41, -0.0000e+00,  1.2067e-01, -0.0000e+00,  1.6935e-01]])

In [180]:
uOv = uOv.to_sparse()
uOv

tensor(indices=tensor([[0, 0, 3, 3, 3, 4, 4, 6, 6, 6],
                       [2, 4, 0, 2, 4, 2, 4, 0, 2, 4]]),
       values=tensor([ 1.5194e-41,  2.1324e-41,  1.5357e-40, -5.6762e-01,
                      -7.9660e-01, -3.4934e-42, -4.9031e-42, -3.2646e-41,
                       1.2067e-01,  1.6935e-01]),
       size=(7, 5), nnz=10, layout=torch.sparse_coo)

In [189]:
# Deflation: subtract the rank-1 component s * u v^T so that the next run of the
# power iteration on Ap targets the second singular triplet of A.
Ap = A - (uOv * s)
Ap

tensor(indices=tensor([[0, 0, 0, 2, 3, 3, 3, 4, 4, 4, 6, 6, 6],
                       [0, 2, 4, 1, 0, 2, 4, 0, 2, 4, 0, 2, 4]]),
       values=tensor([-4.8204e-01, -4.3917e-41, -6.1633e-41,  1.0813e-01,
                      -4.4386e-40, -4.0514e-02,  2.8869e-02,  1.1083e-01,
                       1.0098e-41,  1.4171e-41,  9.4359e-41, -1.9057e-01,
                       1.3579e-01]),
       size=(7, 5), nnz=13, layout=torch.sparse_coo)

In [190]:
Ap.is_coalesced()

True

In [215]:
def svd1D(A, tol=1e-12, max_iter=20):
    """Estimate the dominant right singular vector of ``A`` by power iteration.

    A random unit vector is repeatedly multiplied by ``A^T A`` and renormalised
    until two consecutive iterates are within ``tol`` of each other, or until
    ``max_iter`` steps have been taken.

    Args:
        A: 2-D sparse tensor of shape (m, n); it is passed to
            ``torch.sparse.mm`` to form ``A^T A``.
        tol: Euclidean distance between consecutive iterates at or below which
            the iteration is considered converged. NOTE: the default is far
            below float32 machine epsilon, so with float32 input the loop
            usually ends at ``max_iter`` or at an exact fixed point.
        max_iter: Maximum number of power-iteration steps.

    Returns:
        A 1-D unit-norm tensor of length n. Its sign is arbitrary. If a
        non-convergence message is printed, the last iterate is still returned.
    """
    ATA = torch.sparse.mm(A.t(), A)

    # Size the start vector from A instead of hard-coding 5 columns, and keep
    # it on A's dtype/device so the matrix-vector products are compatible.
    currentV = torch.randn(A.shape[1], dtype=A.dtype, device=A.device)
    currentV /= currentV.norm()

    iterations = 0
    change = torch.tensor(float("inf"))
    while iterations < max_iter:
        iterations += 1
        lastV = currentV
        currentV = ATA @ lastV

        norm = currentV.norm()
        if norm == 0:
            # lastV lies in the null space of A^T A (e.g. A is all zeros):
            # it is already a valid singular vector for singular value 0, and
            # dividing by the zero norm would only produce NaNs.
            return lastV
        currentV = currentV / norm

        change = torch.dist(currentV, lastV)
        if change <= tol:
            return currentV

    print(f"SVD failed to converge in {iterations} iterations:")
    print(f"{change} > {tol}")
    return currentV

In [216]:
# Power iteration on the deflated matrix: its dominant right singular vector
# approximates the second right singular vector of A (compared with V[:, 1] below).
vp = svd1D(Ap)

In [218]:
torch.dist(vp, V.t()[1])

tensor(0.0256)

In [222]:
def get_singular_value(ATA, v):
    """Return the square root of the norm of ``ATA @ v``.

    When ``v`` is a unit-length eigenvector of ``ATA`` (= A^T A), the product
    equals ``s**2 * v``, so the returned value is the singular value ``s``.

    Args:
        ATA: Matrix to multiply ``v`` by.
        v: 1-D vector.

    Returns:
        A 0-dim tensor holding ``sqrt(||ATA @ v||)``.
    """
    return (ATA @ v).norm().sqrt()
# NOTE(review): ATA here is the notebook-level A^T A built from the original A,
# not from the deflated Ap; vp is only an approximate eigenvector of it, so sp
# is an approximation of the second singular value.
sp = get_singular_value(ATA, vp)
sp

tensor(0.4946)

In [223]:
S[1]

tensor(0.4945)

In [224]:
# Second left singular vector estimate, from A vp = sp * up.
up = (A @ vp) / sp
up

tensor([-9.7457e-01,  0.0000e+00, -3.6240e-28, -9.4167e-15,  2.2407e-01,
         0.0000e+00, -4.4295e-14])

In [225]:
torch.dist(up, U.t()[1])

tensor(0.0124)