<a href="https://colab.research.google.com/github/Shurui-Zhang/Deep_learning/blob/main/Lab1ex.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import Tuple
import torch
import random
import numpy as np

def sgd_factorise(A: torch.Tensor, rank: int, num_epochs=1000, lr=0.01) -> Tuple[torch.Tensor, torch.Tensor]:
  """Factorise ``A`` as ``U @ V.T`` with element-wise stochastic gradient descent.

  Both factors start from independent standard-normal draws (``U`` first,
  then ``V``, taken from torch's global RNG). Each epoch visits every entry of
  ``A`` in row-major order. It nudges the matching row of ``U`` and row of
  ``V`` along the residual of that single entry. The ``V`` step deliberately
  uses the row of ``U`` that was just updated.

  Args:
    A: 2-D matrix of shape ``(m, n)``.
    rank: number of columns in each factor.
    num_epochs: number of full passes over the entries of ``A``.
    lr: learning rate of each update.

  Returns:
    ``(U, V)`` of shapes ``(m, rank)`` and ``(n, rank)``.
  """
  n_rows, n_cols = A.shape
  U = torch.normal(0, 1, size=(n_rows, rank))
  V = torch.normal(0, 1, size=(n_cols, rank))

  # Row-major visiting order, identical to a nested row/column loop.
  entries = [(i, j) for i in range(n_rows) for j in range(n_cols)]

  for _ in range(num_epochs):
    for i, j in entries:
      # For 1-D rows ``@`` is already the dot product, so no transpose is needed.
      residual = A[i, j] - U[i, :] @ V[j, :]
      U[i, :] = U[i, :] + lr * residual * V[j, :]
      V[j, :] = V[j, :] + lr * residual * U[i, :]

  return U, V


In [None]:
# Ground-truth 3x3 matrix used by every experiment in this notebook.
z = torch.tensor([[0.3374, 0.6005, 0.1735], [3.3359, 0.0492, 1.8374], [2.9407, 0.5301, 2.2620]], dtype=torch.float)
print("rank of z:", np.linalg.matrix_rank(z))

# sgd_factorise draws its initial factors from torch's global RNG; seed it so
# every number printed here (and in later cells) is reproducible on Run All.
torch.manual_seed(0)
U, V = sgd_factorise(z, 2)
print("U = ", U)
print("V = ", V)
# Sum of squared errors, i.e. the squared Frobenius norm of the residual.
print("error =", torch.nn.functional.mse_loss(U@V.t(), z, reduction='sum'))

3
U =  tensor([[ 0.1509, -0.1521],
        [ 1.6671,  0.0970],
        [ 1.5327, -0.3966]])
V =  tensor([[ 1.9936,  0.3157],
        [ 0.1311, -0.6856],
        [ 1.1526, -1.3099]])
error = tensor(0.2907)


In [None]:
# Truncated SVD is the best rank-k approximation in the Frobenius norm
# (Eckart-Young), so its error lower-bounds the rank-2 SGD error above.
TRUNC_RANK = 2

# torch.svd is deprecated; torch.linalg.svd returns Vh (= V^T) instead of V.
U_svd, S_svd, Vh_svd = torch.linalg.svd(z, full_matrices=False)
print("singular values:", S_svd)
print("full reconstruction:")
print(U_svd @ torch.diag(S_svd) @ Vh_svd)

# Zero the trailing singular values on a copy so S_svd still holds the real
# decomposition instead of being silently overwritten.
S_trunc = S_svd.clone()
S_trunc[TRUNC_RANK:] = 0
z_rank_k = U_svd @ torch.diag(S_trunc) @ Vh_svd
print(f"rank-{TRUNC_RANK} singular values:", S_trunc)
print(f"rank-{TRUNC_RANK} reconstruction:")
print(z_rank_k)
print("error = ", torch.nn.functional.mse_loss(z_rank_k, z, reduction='sum'))

# Sanity check: the columns of U are orthonormal, so U^T U should be the identity.
print("U^T U =")
print(U_svd.t() @ U_svd)

torch.return_types.svd(
U=tensor([[-0.0801, -0.7448,  0.6625],
        [-0.7103,  0.5090,  0.4863],
        [-0.6994, -0.4316, -0.5697]]),
S=tensor([5.3339, 0.6959, 0.3492]),
V=tensor([[-0.8349,  0.2548,  0.4879],
        [-0.0851, -0.9355,  0.3430],
        [-0.5439, -0.2448, -0.8027]]))
tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
the result of using torch.svd:
torch.return_types.svd(
U=tensor([[-0.0801, -0.7448,  0.6625],
        [-0.7103,  0.5090,  0.4863],
        [-0.6994, -0.4316, -0.5697]]),
S=tensor([5.3339, 0.6959, 0.0000]),
V=tensor([[-0.8349,  0.2548,  0.4879],
        [-0.0851, -0.9355,  0.3430],
        [-0.5439, -0.2448, -0.8027]]))
the result of reconstruction:
tensor([[ 0.2245,  0.5212,  0.3592],
        [ 3.2530, -0.0090,  1.9737],
        [ 3.0378,  0.5983,  2.1023]])
error =  tensor(0.1219)
tensor([[1.0000e+00, 3.1216e-07, 1.2579e-07],
        [3.1216e-07, 1.0000e+00, 1.4313e-07],
        [1.2579e-07, 1.4313e

In [None]:
def sgd_factorise_masked(A: torch.Tensor, M: torch.Tensor, rank: int, num_epochs=1000, lr=0.01) -> Tuple[torch.Tensor, torch.Tensor]:
  """Low-rank factorisation of ``A`` fitted only on the entries selected by ``M``.

  Uses the same element-wise SGD update as ``sgd_factorise``. Entries where
  ``M[i, j] != 1`` are treated as missing and never produce an update, so
  ``U @ V.T`` can be used to fill them in (matrix completion).

  Args:
    A: 2-D matrix of shape ``(m, n)``; values at masked-out positions are ignored.
    M: mask with the same shape as ``A``; an entry equal to 1 marks an observed value.
    rank: number of columns in each factor.
    num_epochs: number of full passes over the observed entries.
    lr: learning rate of each update.

  Returns:
    ``(U, V)`` of shapes ``(m, rank)`` and ``(n, rank)`` such that ``U @ V.T``
    approximates ``A`` on the observed entries.

  Raises:
    ValueError: if ``M`` and ``A`` do not have the same shape.
  """
  if M.shape != A.shape:
    raise ValueError(f"mask shape {tuple(M.shape)} does not match matrix shape {tuple(A.shape)}")

  m, n = A.shape
  U = torch.normal(0, 1, size=(m, rank))
  V = torch.normal(0, 1, size=(n, rank))

  # The mask never changes, so collect the observed cells once. The order is
  # row-major, the same as a nested row/column loop, instead of re-testing
  # every cell in every epoch.
  observed = [(row, column) for row in range(m) for column in range(n) if M[row, column] == 1]

  for _ in range(num_epochs):
    for row, column in observed:
      error = A[row, column] - U[row, :] @ V[column, :]
      U[row, :] = U[row, :] + lr * error * V[column, :]
      # Uses the freshly updated U row, matching the unmasked sgd_factorise.
      V[column, :] = V[column, :] + lr * error * U[row, :]

  return U, V

In [None]:
# Hide two entries of z (mask == 0) and try to recover them from the rest.
mask_matrix = torch.tensor([[1, 1, 1], [0, 1, 1], [1, 0, 1]], dtype=torch.float)
original_matrix = z * mask_matrix  # observed entries of z; hidden ones are zeroed

# Every 3x3 matrix has rank <= 3, so a rank-3 factorisation places no
# constraint at all on the hidden cells. Rank 2 (as in the experiments above)
# at least forces the completed matrix to be low-rank. It still does not make
# the completion unique.
COMPLETION_RANK = 2
U_2, V_2 = sgd_factorise_masked(original_matrix, mask_matrix, COMPLETION_RANK)
# Use this cell's masked factors, not U/V from the unmasked fit earlier on.
completed_matrix = U_2 @ V_2.t()

print("the original matrix:")
print(z)
print("the observed (masked) matrix:")
print(original_matrix)
print("the matrix after reconstruction:")
print(completed_matrix)

hidden = mask_matrix == 0
print("error on hidden entries =", torch.nn.functional.mse_loss(completed_matrix[hidden], z[hidden], reduction='sum'))


initial metricx tensor([[-4.0649, -0.7557,  2.8187],
        [-2.2092, -3.5939, -0.8702],
        [ 1.3050, -2.4338, -2.4134]])
the original matrix:
tensor([[0.3374, 0.6005, 0.1735],
        [3.3359, 0.0492, 1.8374],
        [2.9407, 0.5301, 2.2620]])
the matrix after reconstruction:
tensor([[0.2528, 0.1241, 0.3732],
        [3.3542, 0.1520, 1.7944],
        [2.9304, 0.4728, 2.2861]])
