In [1]:
%matplotlib inline
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
def sgd_factorise(A:torch.Tensor, rank:int, num_epochs=1000,
                  lr=0.01, generator: Optional[torch.Generator] = None
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Approximate ``A`` as ``U @ V.t()`` using per-entry stochastic gradient descent.

    Args:
        A: 2-D matrix of shape ``(m, n)`` to factorise.
        rank: number of columns in each factor.
        num_epochs: number of full sweeps over every entry of ``A``.
        lr: learning rate (step size) applied to every update.
        generator: optional ``torch.Generator`` used for the random
            initialisation, so results can be reproduced. It must live on
            ``A``'s device. Defaults to torch's global RNG.

    Returns:
        ``(U, V)`` with shapes ``(m, rank)`` and ``(n, rank)``, on ``A``'s
        device, in ``A``'s floating dtype (or the default float dtype when
        ``A`` is an integer tensor).

    Raises:
        ValueError: if ``A`` is not 2-D.
    """
    if A.dim() != 2:
        raise ValueError(f"A must be a 2-D matrix, got shape {tuple(A.shape)}")
    m, n = A.shape
    # Create the factors on A's device (CPU factors combined with a CUDA A
    # would fail). Promote integer/half dtypes to at least the default float
    # dtype, but keep float64 inputs in float64 instead of downcasting them.
    dtype = torch.promote_types(A.dtype, torch.get_default_dtype())
    U = torch.rand(m, rank, generator=generator, dtype=dtype, device=A.device)
    V = torch.rand(n, rank, generator=generator, dtype=dtype, device=A.device)
    for _ in range(num_epochs):
        for r in range(m):
            for c in range(n):
                # Residual of the current estimate for this single entry.
                e = A[r, c] - (U[r] @ V[c])
                U[r] = U[r] + lr * e * V[c]
                # Sequential update: V's step uses the U row updated just above,
                # not the pre-update row a simultaneous gradient step would use.
                V[c] = V[c] + lr * e * U[r]
    return U, V

In [3]:
# Seed torch's global RNG so that sgd_factorise's random initialisation, and
# therefore every number reported below, is reproducible on Restart & Run All.
torch.manual_seed(0)

A = torch.tensor([[0.3374, 0.6005, 0.1735],
                  [3.3359, 0.0492, 1.8374],
                  [2.9407, 0.5301, 2.2620]
                 ])
# Rank-2 factorisation of the 3x3 matrix A via SGD.
U,V = sgd_factorise(A, 2)

In [4]:
recon_A = torch.matmul(U, V.T)  # SGD reconstruction: A ≈ U Vᵀ

In [5]:
# reduction='sum' gives the total squared error sum((recon_A - A)**2), i.e. the
# squared Frobenius error, not a per-element mean. It is left as the cell's last
# expression so Jupyter renders it as the cell output.
torch.nn.functional.mse_loss(recon_A, A, reduction='sum')

tensor(0.1225)


In [6]:
# Show the SGD reconstruction U @ V.t() for a side-by-side comparison with A.
recon_A

tensor([[ 0.2175,  0.4999,  0.3702],
        [ 3.2599, -0.0145,  1.9622],
        [ 3.0285,  0.6039,  2.1178]])

In [7]:
# torch.svd is deprecated in favour of torch.linalg.svd. The new function returns
# Vh (the conjugate transpose of V), so convert it back so that V1 keeps
# torch.svd's meaning: A ≈ U1 @ diag(S1) @ V1.t().
# full_matrices=False matches torch.svd's default reduced output.
U1, S1, Vh1 = torch.linalg.svd(A, full_matrices=False)
V1 = Vh1.mH

In [10]:
# Truncated SVD: keep only the svd_rank largest singular values (S1 is sorted in
# descending order). This gives the best rank-k approximation in Frobenius norm.
# Slicing, instead of zeroing S1 in place, leaves S1 intact: the cell stays
# idempotent and later cells still see the full spectrum.
svd_rank = 2  # same rank as the SGD factorisation
recon_A_ = U1[:, :svd_rank] @ torch.diag(S1[:svd_rank]) @ V1[:, :svd_rank].t()

In [11]:
# Show the truncated-SVD reconstruction (smallest singular value dropped) for comparison with recon_A.
recon_A_

tensor([[ 0.2245,  0.5212,  0.3592],
        [ 3.2530, -0.0090,  1.9737],
        [ 3.0378,  0.5983,  2.1023]])

In [12]:
# Total squared (Frobenius) error of the truncated-SVD reconstruction. With
# reduction='sum' this is directly comparable to the SGD error computed the same
# way. The truncated SVD is the optimal approximation of that rank, so SGD can at
# best match it. It is left as the last expression so Jupyter renders it.
torch.nn.functional.mse_loss(recon_A_, A, reduction='sum')

tensor(0.1219)
