# LFM梯度下降算法实现

### 0. 引入依赖

In [1]:
import numpy as np
import pandas as pd

### 1. 数据准备

In [2]:
# Rating matrix R; entries equal to 0 are treated as "no rating" by the
# factorization routine further down (it only visits entries > 0).
R = np.array(
    [
        [4, 0, 2, 0, 1],
        [0, 2, 3, 0, 0],
        [1, 0, 2, 4, 0],
        [5, 0, 0, 3, 1],
        [0, 0, 1, 5, 1],
        [0, 3, 2, 4, 1],
    ]
)
# Number of columns in the rating matrix
R.shape[1]

5

### 2. 算法实现

In [3]:
"""
@输入参数：
R：M*N 的评分矩阵
K：隐特征向量维度
max_iter：最大迭代次数
alpha：步长
lamda：正则化系数

@输出：
分解之后的 P、Q，以及迭代结束时的损失值 cost
P：用户特征矩阵 M*K
Q：物品特征矩阵 N*K
cost：最后一次迭代计算得到的损失函数值
"""

# Hyperparameters for the demo run
K, max_iter = 5, 5000        # latent feature dimension, iteration cap
alpha, lamda = 2e-4, 4e-3    # step size, regularization coefficient

# 核心算法
def LFM_grad_desc( R, K=2, max_iter=1000, alpha=0.0001, lamda=0.002 ):
    """Factorize a rating matrix with a latent factor model via gradient descent.

    Only entries of ``R`` that are greater than 0 are treated as observed
    ratings; every other entry is ignored by both the updates and the cost.

    Args:
        R: M*N rating matrix (anything ``np.asarray`` turns into a 2-D array).
        K: Dimension of the latent feature vectors.
        max_iter: Maximum number of full passes over the observed ratings.
        alpha: Step size.
        lamda: Regularization coefficient.

    Returns:
        A tuple ``(P, Q, cost)`` where ``P`` is the M*K user factor matrix,
        ``Q`` is the N*K item factor matrix and ``cost`` is the regularized
        squared error of the returned factors over the observed ratings.
    """
    R = np.asarray(R)
    M, N = R.shape

    # Collect the observed ratings once instead of rescanning R on every pass.
    observed = [
        (u, i, R[u, i])
        for u in range(M)
        for i in range(N)
        if R[u, i] > 0
    ]

    # Random initial factors; Q is kept as N*K so that row i is item i's vector.
    P = np.random.rand(M, K)
    Q = np.random.rand(N, K)

    # Defined up front so that max_iter == 0 still returns a meaningful cost.
    cost = _lfm_cost(observed, P, Q, lamda)

    for _ in range(max_iter):
        for u, i, r in observed:
            err = np.dot(P[u], Q[i]) - r

            # Snapshot Pu so both gradients are evaluated at the same point;
            # otherwise Qi's step would be computed from the already-moved Pu.
            p_old = P[u].copy()
            P[u] -= alpha * (2 * err * Q[i] + 2 * lamda * p_old)
            Q[i] -= alpha * (2 * err * p_old + 2 * lamda * Q[i])

        cost = _lfm_cost(observed, P, Q, lamda)
        if cost < 0.0001:
            break

    return P, Q, cost


def _lfm_cost(observed, P, Q, lamda):
    """Return the regularized squared error summed over the observed ratings."""
    total = 0.0
    for u, i, r in observed:
        total += (np.dot(P[u], Q[i]) - r) ** 2
        # Regularization term, added once per observed rating.
        total += lamda * (np.dot(P[u], P[u]) + np.dot(Q[i], Q[i]))
    return total

### 3. 测试

In [4]:
# Factorize the rating matrix with the hyperparameters chosen earlier
P, Q, cost = LFM_grad_desc(R, K=K, max_iter=max_iter, alpha=alpha, lamda=lamda)

# Show the learned factors followed by the final cost, one per line
print(P, Q, cost, sep="\n")

# Rebuild the full predicted rating matrix from the two factor matrices
predR = np.dot(P, Q.T)

# Original ratings first, then the prediction as the cell's displayed value
print(R)
predR

[[ 1.20228712  0.2163174   0.7221636   0.42297014  0.96337102]
 [ 0.64706519  0.37203219  1.22170939  0.81407379  1.32787257]
 [ 0.75393526  0.67937964 -0.19127409  1.47400963  0.43731798]
 [ 1.00876416 -0.16999092  1.3076031   0.22905326  1.13485348]
 [ 0.95527577  1.25781822  0.77994284  0.93501153 -0.32947453]
 [ 1.52664365  0.18274245  1.01764047  0.35637914  0.69029207]]
[[ 1.20685233 -0.19093666  1.68072371 -0.02144532  1.38281313]
 [ 1.40886966  0.40580216  0.49848527  0.34758071  0.08367186]
 [ 0.2806506  -0.04990035  0.46448909  0.91850157  1.09114708]
 [ 1.47507932  1.32235372  0.85917103  1.39159692  0.24621947]
 [ 0.29031004  0.37331968  0.27010392  0.1400478   0.31066989]]
0.5813731297560372
[[4 0 2 0 1]
 [0 2 3 0 0]
 [1 0 2 4 0]
 [5 0 0 3 1]
 [0 0 1 5 1]
 [0 3 2 4 1]]


array([[3.94652892, 2.36925912, 2.1017436 , 3.50578366, 0.98337718],
       [4.58197503, 2.06566802, 2.9271377 , 3.9558986 , 1.18326421],
       [1.03180958, 1.79147143, 1.91990475, 4.0050608 , 0.76313021],
       [5.01198457, 2.1786252 , 2.34763809, 2.98484657, 0.96722533],
       [1.74792828, 2.54249864, 1.06691247, 4.96252856, 0.98614854],
       [4.46481594, 2.9139065 , 1.97256241, 4.0337973 , 1.05065307]])