In [1]:
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

# All requested packages already installed.


Note: you may need to restart the kernel to use updated packages.


In [24]:
import torch
from typing import Tuple

def sgd_factorise(A: torch.Tensor, rank: int, num_epochs=1000, lr = 0.01) ->Tuple[torch.Tensor, torch.tensor]:
    N = num_epochs
    m = A.size()[0]
    n = A.size()[1]
    r = rank
    U = torch.randn((m,r))
    V = torch.randn((n,r))
    for epoch in range(N):
        for r in range(m):
            for c in range(n):
                e = A[r][c] - U[r,:]@V[c,:].t()
                U[r,:] += lr*e*V[c,:]
                V[c,:] += lr*e*U[r,:]
    return U, V     

In [135]:
A = torch.Tensor([[0.3374,0.6005,0.1735],[3.3359,0.0492,1.8374],[2.9407, 0.5301,2.2620]]) 
U_,V_ = sgd_factorise(A,2)
print("U_:",U_)
print("V_:",V_)
print("A*:",U_@V_.t())
print("A:",A)
print("Error:",torch.nn.functional.mse_loss(U_@V_.t(),A,reduction ='sum'))

U_: tensor([[ 0.5913, -0.4994],
        [ 1.5092,  0.8254],
        [ 1.9775,  0.1241]])
V_: tensor([[ 1.4551,  1.2777],
        [ 0.3416, -0.6369],
        [ 1.0341,  0.5031]])
A*: tensor([[ 0.2222,  0.5201,  0.3602],
        [ 3.2506, -0.0101,  1.9759],
        [ 3.0360,  0.5966,  2.1074]])
A: tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
Error: tensor(0.1220)


In [138]:
import torch
U,S,V_T = torch.svd(A)
S[2] = 0
SS_ = torch.diag(S)
A_ = U@SS_@V_T.t()
print("U:",U)
print("SS_",SS_)
print("V_T:",V_T)
print("A*:",U@SS_@V_T.t())
print("A:",A)
print("Error:",torch.nn.functional.mse_loss(U@SS_@V_T.t(),A,reduction ='sum'))

U: tensor([[-0.0801, -0.7448,  0.6625],
        [-0.7103,  0.5090,  0.4863],
        [-0.6994, -0.4316, -0.5697]])
SS_ tensor([[5.3339, 0.0000, 0.0000],
        [0.0000, 0.6959, 0.0000],
        [0.0000, 0.0000, 0.0000]])
V_T: tensor([[-0.8349,  0.2548,  0.4879],
        [-0.0851, -0.9355,  0.3430],
        [-0.5439, -0.2448, -0.8027]])
A*: tensor([[ 0.2245,  0.5212,  0.3592],
        [ 3.2530, -0.0090,  1.9737],
        [ 3.0378,  0.5983,  2.1023]])
A: tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
Error: tensor(0.1219)


In [10]:
import torch
from scipy.linalg import sqrtm
U,S,V_T = torch.svd(A)
print("S:",S)
S[2] = 0
SS_ = torch.diag(S)
U_ = U@sqrtm(SS_)
V_T_ = V_T@sqrtm(SS_)
A_ = U_@V_T_.t()
print("U_:",U)
print("V_T_:",V_T_)
print("A*:",A_)
print("A:",A)
print("Error:",torch.nn.functional.mse_loss(A_,A,reduction ='sum'))

S: tensor([5.3339, 0.6959, 0.3492])
U_: tensor([[-0.0801, -0.7448,  0.6625],
        [-0.7103,  0.5090,  0.4863],
        [-0.6994, -0.4316, -0.5697]])
V_T_: tensor([[-1.9281,  0.2126,  0.0000],
        [-0.1965, -0.7804,  0.0000],
        [-1.2561, -0.2042,  0.0000]], dtype=torch.float64)
A*: tensor([[ 0.2245,  0.5212,  0.3592],
        [ 3.2530, -0.0090,  1.9737],
        [ 3.0378,  0.5983,  2.1023]], dtype=torch.float64)
A: tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
Error: tensor(0.1219, dtype=torch.float64)


In [139]:
def sgd_factorise_masked(A: torch.Tensor, M:torch.Tensor, rank: int, num_epochs=1000, lr = 0.01) ->Tuple[torch.Tensor, torch.tensor]:
    N = num_epochs
    m = A.size()[0]
    n = A.size()[1]
    r = rank
    U = torch.randn((m,r))
    V = torch.randn((n,r))
#     U = torch.zeros()
    
    for epoch in range(N):
#         e = A - U@V.t()
        for r in range(m):
            for c in range(n):
                if M[r][c] == 1:
                    e = A[r][c] - U[r,:]@V[c,:].t()
                    U[r,:] += lr*e*V[c,:]
                    V[c,:] += lr*e*U[r,:]
                else:
                    continue            
    return U, V 

In [377]:
A = torch.Tensor([[0.3374,0.6005,0.1735],[3.3359,0.0492,1.8374],[2.9407, 0.5301,2.2620]]) 
M = torch.Tensor([[1,1,1],[0,1,1],[1,0,1]]) 
U__,V__ = sgd_factorise_masked(A,M,2)
print("U__:",U__)
print("V__:",V__)
print("A**:",U__@V__.t())
print("A:",A)
print("Error:",torch.nn.functional.mse_loss(U__@V__.t(),A,reduction ='sum'))

U__: tensor([[-0.0023,  0.2080],
        [ 1.1480,  1.0990],
        [ 1.1147,  1.3024]])
V__: tensor([[ 1.6743,  0.8307],
        [-1.4381,  1.5634],
        [-0.4038,  2.0792]])
A**: tensor([[0.1690, 0.3284, 0.4334],
        [2.8350, 0.0671, 1.8215],
        [2.9481, 0.4330, 2.2577]])
A: tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
Error: tensor(0.4309)


In [345]:
U__= torch.Tensor([[ 0.5913, -0.4994],
        [ 1.5092,  0.8254],
        [ 1.9775,  0.1241]])
V__= torch.Tensor([[ 1.4551,  1.2777],
        [ 0.3416, -0.6369],
        [ 1.0341,  0.5031]])
print(U__@V__.t())
print("Error:",torch.nn.functional.mse_loss(U__@V__.t(),A,reduction ='sum'))

tensor([[ 0.2223,  0.5201,  0.3602],
        [ 3.2507, -0.0102,  1.9759],
        [ 3.0360,  0.5965,  2.1074]])
Error: tensor(0.1220)
