In [None]:
from typing import Tuple
import torch

In [None]:
def sgd_factorise(A: torch.Tensor, rank: int, epochs = 1000, lr = 0.01):
  """Approximate ``A`` as ``U @ V`` with per-entry stochastic gradient descent.

  Each epoch visits every entry ``(i, j)`` of ``A`` in row-major order and
  nudges row ``i`` of ``U`` and column ``j`` of ``V`` to shrink that entry's
  residual.

  Args:
    A: 2-D target matrix of shape (m, n).
    rank: inner dimension of the factorisation.
    epochs: number of full sweeps over the entries of ``A``.
    lr: learning rate applied to every per-entry update.

  Returns:
    ``(U, V)`` with shapes ``(m, rank)`` and ``(rank, n)``. Both are
    initialised with ``torch.rand``, so results depend on the global torch
    RNG state.
  """
  num_rows, num_cols = A.shape
  U = torch.rand(num_rows, rank)
  V = torch.rand(rank, num_cols)

  # Visit order is fixed, so build the index list once.
  cells = [(i, j) for i in range(num_rows) for j in range(num_cols)]
  for _ in range(epochs):
    for i, j in cells:
      residual = A[i, j] - torch.dot(U[i], V[:, j])
      U[i] = U[i] + (lr * residual * V[:, j])
      # The V step deliberately uses the row of U that was just updated above.
      V[:, j] = V[:, j] + (lr * residual * U[i])

  return U, V

In [None]:
def truncated_svd(A: torch.Tensor, truncate = -1):
  """Return an SVD of ``A`` whose trailing singular values are zeroed.

  Singular values come back from ``torch.linalg.svd`` in descending order.
  Every value from index ``truncate`` onwards (Python slice semantics) is set
  to zero, so ``U @ S @ Vh`` is a low-rank approximation of ``A``:

  * ``truncate=-1`` (default): drop only the smallest singular value;
  * ``truncate=k`` with ``k >= 0``: keep the ``k`` largest singular values.

  Args:
    A: 2-D matrix of shape (m, n).
    truncate: index of the first singular value to discard.

  Returns:
    ``(U, S, Vh)`` with ``U`` of shape (m, k), ``S`` the (k, k) diagonal
    matrix of truncated singular values and ``Vh`` of shape (k, n), where
    ``k = min(m, n)``. ``Vh`` is already the conjugate transpose of V, so
    ``U @ S @ Vh`` reconstructs the approximation directly.
  """
  # Reduced SVD keeps the factors conformable for non-square A; the default
  # full SVD would return a (m, m) U and a (n, n) Vh that cannot be multiplied
  # through a (k, k) diagonal.
  U, S, Vh = torch.linalg.svd(A, full_matrices=False)
  # Write into a copy: mutating the SVD's own output in place would break
  # autograd when A requires grad.
  S_truncated = S.clone()
  S_truncated[truncate:] = 0
  return U, torch.diag(S_truncated), Vh

NameError: ignored

In [None]:
# Fix the RNG so the stochastic SGD factorisation (and the errors printed
# below) are reproducible on a fresh kernel.
torch.manual_seed(0)

A = torch.tensor([[0.3374, 0.6005, 0.1735], [3.3359, 0.0492, 1.8374], [2.9407, 0.5301, 2.2620]])

# Two rank-2 approximations of A: one learned by SGD, one from the SVD with
# its smallest singular value zeroed (the best rank-2 approximation in
# Frobenius norm, so it is the error floor the SGD result can approach).
U, V = sgd_factorise(A, 2)
X, S, Y = truncated_svd(A)

print(U)
print(V)

# reduction='sum' makes these the *sum* of squared errors (the squared
# Frobenius norm of the residual), not a mean.
svdGradMse = torch.nn.functional.mse_loss(U @ V, A, reduction='sum')
truncatedSvdMse = torch.nn.functional.mse_loss(X @ S @ Y, A, reduction='sum')

print("SGD factorisation error: " + str(svdGradMse))
print("Truncated SVD error: " + str(truncatedSvdMse))

tensor([[ 0.6966, -0.2414],
        [ 0.5549,  1.4885],
        [ 1.1043,  1.1078]])
tensor([[ 0.9785,  0.7543,  0.8107],
        [ 1.7955, -0.2584,  1.0558]])
SVD Gradient error: tensor(0.1310)
Truncated SVD error: tensor(0.1219)


In [None]:
def masked_factorisation(A: torch.Tensor, mask:torch.Tensor, rank:int, epochs = 1000, lr = 0.01):
  [m, n] = A.shape
  U = torch.rand(m, rank)
  V = torch.rand(rank, n)

  for epoch in range(epochs):
    for r in range(m):
      for c in range(n):
        if (mask[r, c] == 1):
          e = A[r, c] - torch.dot(U[r, :], V[:, c])
          U[r, :] = U[r, :] + (lr * e * V[:, c])
          V[:, c] = V[:, c] + (lr * e * U[r, :])
          
  return U, V

In [None]:
# B equals the fully-known A (defined in the earlier comparison cell) except at
# the two positions the mask hides, [1, 0] and [2, 1], where its values differ.
# The factorisation only sees B's unmasked entries, so comparing U @ V with A
# measures how well those hidden entries are recovered.
torch.manual_seed(0)  # reproducible random initialisation of U and V

B = torch.tensor([[0.3374, 0.6005, 0.1735], [0.7374, 0.0492, 1.8374], [2.9407, 0.673, 2.2620]])
mask = torch.tensor([[1, 1, 1], [0, 1, 1], [1, 0, 1]], dtype=torch.float32)

U, V = masked_factorisation(B, mask, 2)
# reduction='sum' gives the sum (not mean) of squared errors against A.
maskedFactMse = torch.nn.functional.mse_loss(U @ V, A, reduction='sum')
print(maskedFactMse)
print(U @ V)
print(A)

tensor(1.0302)
tensor([[0.3364, 0.6006, 0.1747],
        [2.3216, 0.0492, 1.8377],
        [2.9410, 0.4924, 2.2616]])
tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
