Skip to content

Commit

Permalink
Added the Spherical k-means based recommender
Browse files Browse the repository at this point in the history
  • Loading branch information
Aghiles SALAH committed Oct 14, 2018
1 parent fe7618c commit f0c9705
Show file tree
Hide file tree
Showing 18 changed files with 591 additions and 4 deletions.
1 change: 1 addition & 0 deletions cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .pcrl import *
from .pmf import *
from .ibpr import *
from .skm import *
2 changes: 1 addition & 1 deletion cornac/models/hpf/recom_hpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Hpf(Recommender):
The name of the recommender model.
trainable: boolean, optional, default: True
When False, the model is not trained and Cornac assumes that the model already \
When False, the model is not trained and Cornac assumes that the model is already \
pre-trained (Theta and Beta are not None).
init_params: dictionary, optional, default: None
Expand Down
5 changes: 5 additions & 0 deletions cornac/models/skm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

from .recom_skmeans import Skmeans


__all__ = ['Skmeans']
106 changes: 106 additions & 0 deletions cornac/models/skm/recom_skmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
@author: Aghiles Salah <asalah@smu.edu.sg>
"""

import numpy as np
from ..recommender import Recommender
from .skmeans import *


class Skmeans(Recommender):
"""Spherical k-means based recommender.
Parameters
----------
k: int, optional, default: 5
The number of clusters.
max_iter: int, optional, default: 100
Maximum number of iterations.
name: string, optional, default: 'Skmeans'
The name of the recommender model.
trainable: boolean, optional, default: True
When False, the model is not trained and Cornac assumes that the model is already \
pre-trained (Theta and Beta are not None).
tol : float, optional, default: 1e-6
Relative tolerance with regards to skmeans' criterion to declare convergence.
verbose: boolean, optional, default: True
When True, the skmeans criterion (likelihood) is displayed after each iteration.
init_par: numpy 1d array, optional, default: None
The initial object parition, 1d array contaning the cluster label (int type starting from 0) \
of each object (user). If par = None, then skmeans is initialized randomly.
centroids: csc_matrix, shape (k,n_users)
The maxtrix of cluster centroids.
References
----------
* Salah, Aghiles, Nicoleta Rogovschi, and Mohamed Nadif. "A dynamic collaborative filtering system \
via a weighted clustering approach." Neurocomputing 175 (2016): 206-215.
"""

def __init__(self, k=5, max_iter=100, name = "Skmeans",trainable = True, tol = 1e-6, verbose = True, init_par = None):
Recommender.__init__(self,name=name, trainable = trainable)
self.k = k
self.par = init_par
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose

self.centroids = None #matrix of cluster centroids


#fit the recommender model to the traning data
def fit(self,X):
"""Fit the model to observations.
Parameters
----------
X: scipy sparse matrix, required
the user-item preference matrix (traning data), in a scipy sparse format\
(e.g., csc_matrix).
"""
if self.trainable:
#Skmeans requires rows of X to have a unit L2 norm. We therefore need to make a copy of X as we should not modify the latter.
X1 = X.copy()
X1 = X1.multiply(sp.csc_matrix(1./(np.sqrt(X1.multiply(X1).sum(1).A1)+1e-20)).T)
res = skmeans(X1,k = self.k, max_iter = self.max_iter,tol = self.tol,verbose = self.verbose,init_par = self.par)
self.centroids = res['centroids']
self.par = res['partition']
self.user_center_sim = X1*self.centroids.T #user-centroid cosine similarity matrix
del(X1)
else:
print('%s is trained already (trainable = False)' % (self.name))




#get prefiction for a single user (predictions for one user at a time for efficiency purposes)
#predictions are not stored for the same efficiency reasons"""
def predict(self,index_user):
"""Predic the scores (ratings) of a user for all items.
Parameters
----------
index_user: int, required
The index of the user for whom to perform predictions.
Returns
-------
Numpy 1d array
Array containing the predicted values for all items
"""
user_pred = self.centroids.multiply(self.user_center_sim[index_user,:].T)
#transform user_pred to a flatten array
user_pred = user_pred.sum(0).A1/(self.user_center_sim[index_user,:].sum()+1e-20) #weighted average of cluster centroids

return user_pred




63 changes: 63 additions & 0 deletions cornac/models/skm/skmeans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
@author: Aghiles Salah
"""

import numpy as np
import scipy.sparse as sp


def skmeans(X,k=5, max_iter = 100,tol = 1e-6, verbose = True,init_par = None):
#The Spherical k-means clustering algorithm

n = X.shape[0]
#normalize rows of X so as they lie on a unit hypersphere
X = X.multiply(sp.csc_matrix(1./(np.sqrt(X.multiply(X).sum(1).A1)+1e-20)).T)
if init_par is None:
par= np.random.randint(k,size=n)
else:
par = init_par

#Initialisation of the classification matrix Z
Z=sp.csc_matrix((n,k))
Z[np.arange(n),par]=1


change = True
l_init= -1e1000
l=[]
iter_ = 0
while change :
change=False
#Update centroids
MU=Z.T*X
#project centroids to the unit hypersphere
MU = MU.multiply(sp.csc_matrix(1./np.sqrt(MU.multiply(MU).sum(1).A1)).T)
#MU = sp.csc_matrix(MU)

#Object Assignements
Z1=X*MU.T
par = Z1.argmax(1).A1 #The object partition in k clusters
#update the classification matrix
Z=sp.csc_matrix((n,k))
Z[np.arange(len(par)), par] = 1


#Skmeans criteria (likelihood)
l_t = Z1.multiply(Z).sum()

if np.abs(l_t - l_init) > tol:
if verbose:
print('Iter %i, likelihood: %f' % (iter_, l_t))
l_init=l_t
change=True
l.append(l_t)
iter_+=1

return {"centroids": MU, "partition": par}






Binary file modified dist/cornac-0.1.0-cp36-cp36m-win_amd64.whl
Binary file not shown.
Binary file modified docs/build/html/.doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/html/.doctrees/models.doctree
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/build/html/_modules/cornac/models/hpf/recom_hpf.html
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ <h1>Source code for cornac.models.hpf.recom_hpf</h1><div class="highlight"><pre>
<span class="sd"> The name of the recommender model.</span>

<span class="sd"> trainable: boolean, optional, default: True</span>
<span class="sd"> When False, the model is not trained and Cornac assumes that the model already \</span>
<span class="sd"> When False, the model is not trained and Cornac assumes that the model is already \</span>
<span class="sd"> pre-trained (Theta and Beta are not None). </span>

<span class="sd"> init_params: dictionary, optional, default: None</span>
Expand Down

0 comments on commit f0c9705

Please sign in to comment.