In [1]:
import numpy as np

In [2]:
class MatrixFactorization(object):
    def __init__(self, data, k):
        '''
        Arguments:
        - data    : 2 dimensional rating matrix
        - k       : number of latent dimensions
        '''
        
        self.R = np.matrix(data)
        self.D = np.zeros( self.R.shape )
        self.k = k
        
        self.U = 2*(np.random.uniform( size=(self.R.shape[0], k) )-.5)
        self.P = 2*(np.random.uniform( size=(k, self.R.shape[1]) )-.5)
    
    def _compure_error(self):
        self.D = (self.R - self.estimate_all())
        
        return self.D
    
    def train(self, alpha=0.1, beta=0.01, iterations=1000):
        '''
        Arguments:
        - alpha   : learning-rate 
        - beta    : regularization-rate
        '''
        
        for _ in range(iterations):
            self._compure_error()
            
            U = self.U
            P = self.P
            
            for i in range(self.R.shape[0]):      
                for j in range(self.R.shape[1]):
                    for k in range(self.k):
                        ik = (alpha/self.k) * P[k, j] * self.D[i, j] 
                        kj = (alpha/self.k) * U[i, k] * self.D[i, j]
                        if not np.isnan(ik):
                            self.U[i, k] += ik
                        if not np.isnan(kj):
                            self.P[k, j] += kj
    
    def estimate_all(self):
        return self.U.dot(self.P)
    
    def estimate(self, x, y):
        return self.U[x, :].dot(self.P[:, y])

In [19]:
R = np.array([
    [np.NaN, 3, 0, 1],
    [4, 0, 0, 1],
    [1, 1, 0, 5],
    [1, 0, 0, 4],
    [0, 1, 5, 4],
])

In [20]:
MF = MatrixFactorization(data=R, k=2)

In [21]:
MF.train(alpha=0.01, iterations=100000)

In [22]:
MF.estimate_all()

array([[  6.37665854e+01,   2.97687853e+00,  -7.64611859e-02,
          1.04013556e+00],
       [  3.98951939e+00,   3.03387274e-01,   3.40742477e-01,
          7.87954360e-01],
       [  1.02964996e+00,   7.11692509e-01,   1.95623886e+00,
          4.11204517e+00],
       [  9.88508505e-01,   5.61066461e-01,   1.51765618e+00,
          3.19370766e+00],
       [ -1.51235605e-02,   8.42020795e-01,   2.48578503e+00,
          5.20025115e+00]])