In [4]:
import torch

def _solve_factors(ratings, fixed_factors, ridge):
    """
    Solve one ALS half-step: a ridge-regularized least-squares fit per row.

    Parameters:
    - ratings (torch.Tensor): Matrix of shape (n_rows, n_cols); only entries
      greater than 0 are treated as observed.
    - fixed_factors (torch.Tensor): Factors held constant during this
      half-step, of shape (n_cols, latent_vector_size).
    - ridge (torch.Tensor): Regularization term lambda_reg * identity, of
      shape (latent_vector_size, latent_vector_size).

    Returns:
    - solved (torch.Tensor): New factors of shape (n_rows, latent_vector_size).
      Rows without any observed entry are left as zero vectors.
    """
    n_rows = ratings.shape[0]
    solved = torch.zeros(
        n_rows,
        fixed_factors.shape[1],
        dtype=fixed_factors.dtype,
        device=fixed_factors.device,
    )

    for row in range(n_rows):
        # as_tuple=True always yields a 1-D index tensor, even when exactly
        # one entry is observed. A bare .squeeze() would collapse that case
        # to a 0-d index and silently drop a dimension from the gathered
        # factors.
        observed = torch.nonzero(ratings[row] > 0, as_tuple=True)[0]
        if observed.numel() == 0:
            # Nothing to fit: the regularized solution is the zero vector,
            # and skipping the solve avoids a singular system when
            # lambda_reg == 0.
            continue

        # rel shape: (num_observed, latent_vector_size)
        rel = fixed_factors[observed]

        # Normal equations: (rel^T rel + lambda * Id) x = rel^T r
        lhs = rel.t() @ rel + ridge
        # Matrix-vector product keeps rhs 1-D with shape
        # (latent_vector_size,), including when latent_vector_size == 1.
        rhs = rel.t() @ ratings[row, observed]
        solved[row] = torch.linalg.solve(lhs, rhs)

    return solved


def alternating_least_squares(M, latent_vector_size, lambda_reg=0.01, max_iter=100, tol=1e-4):
    """
    Alternating Least Squares (ALS) algorithm for matrix factorization.

    Only entries of M that are greater than 0 are treated as observed
    ratings; all other entries are ignored when fitting the factors.

    Parameters:
    - M (torch.Tensor): Ratings matrix of shape (n_users, n_items). Integer
      tensors are converted to the default floating-point dtype.
    - latent_vector_size (int): Number of latent factors.
    - lambda_reg (float): Regularization parameter.
    - max_iter (int): Maximum number of iterations.
    - tol (float): Tolerance for convergence; iteration stops once the summed
      norms of the changes in U and I fall below this value.

    Returns:
    - U (torch.Tensor): User factors matrix of shape (n_users, latent_vector_size).
    - I (torch.Tensor): Item factors matrix of shape (n_items, latent_vector_size).
    """
    if not M.is_floating_point():
        # The linear solves need floating-point inputs.
        M = M.to(torch.get_default_dtype())

    n_users, n_items = M.shape

    # Initialize user and item factors randomly, on M's dtype and device so
    # the matrix products below never mix dtypes or devices.
    U = torch.randn(n_users, latent_vector_size, dtype=M.dtype, device=M.device)
    I = torch.randn(n_items, latent_vector_size, dtype=M.dtype, device=M.device)

    # The regularization term is loop-invariant, so build it once.
    ridge = lambda_reg * torch.eye(latent_vector_size, dtype=M.dtype, device=M.device)
    M_t = M.t()

    # Alternate between updating U (items fixed) and I (users fixed)
    for _ in range(max_iter):
        # _solve_factors returns fresh tensors, so the previous factors can
        # be kept by reference without cloning.
        U_old, I_old = U, I
        U = _solve_factors(M, I, ridge)
        I = _solve_factors(M_t, U, ridge)

        # Check convergence
        delta = torch.norm(U - U_old) + torch.norm(I - I_old)
        if delta < tol:
            break

    return U, I


# Example usage:
import numpy as np

# Build a reproducible toy ratings matrix with integer entries in [0, 5].
np.random.seed(0)
n_users, n_items = 100, 50
latent_vector_size = 10
M = np.random.randint(0, 6, size=(n_users, n_items))

# Hand the ratings to PyTorch as a float32 tensor.
M = torch.from_numpy(M).to(torch.float32)

# Factorize the ratings with a regularization strength of 1.
U, I = alternating_least_squares(M, latent_vector_size, lambda_reg=1)

# Multiply the factors back together to approximate the ratings matrix.
M_reconstructed = torch.matmul(U, I.t())

# Show the input next to its low-rank reconstruction.
print("Original Ratings Matrix:\n", M)
print("\nReconstructed Ratings Matrix:\n", M_reconstructed)


Original Ratings Matrix:
 tensor([[4., 5., 0.,  ..., 4., 2., 0.],
        [0., 4., 5.,  ..., 0., 5., 0.],
        [1., 2., 4.,  ..., 0., 4., 0.],
        ...,
        [1., 0., 0.,  ..., 2., 2., 3.],
        [0., 4., 3.,  ..., 5., 0., 3.],
        [1., 4., 0.,  ..., 4., 3., 0.]])

Reconstructed Ratings Matrix:
 tensor([[3.9281, 4.0568, 2.8297,  ..., 2.5350, 2.4505, 3.3963],
        [1.3158, 4.2402, 3.7553,  ..., 3.4526, 3.0061, 2.7470],
        [1.4151, 2.8391, 3.0626,  ..., 3.2299, 3.8750, 2.8476],
        ...,
        [1.7969, 3.6523, 5.4935,  ..., 3.8226, 2.7363, 2.3702],
        [3.8156, 2.7519, 4.0085,  ..., 4.7571, 3.6447, 3.9910],
        [1.1499, 3.7948, 4.9516,  ..., 3.4037, 2.9100, 2.2143]])
