Skip to content

Commit

Permalink
Added random_state option for MF algo and CoClustering to controle
Browse files Browse the repository at this point in the history
randomization.

Also moved get_rng() in utils.py
  • Loading branch information
NicolasHug committed Jan 6, 2018
1 parent 936b809 commit 9a6c673
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 110 deletions.
18 changes: 1 addition & 17 deletions surprise/model_selection/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,7 @@

import numpy as np


def get_rng(random_state):
'''Return a 'validated' RNG.
If random_state is None, use RandomState singleton from numpy. Else if
it's an integer, consider it's a seed and initialized an rng with that
seed. If it's already an rng, return it.
'''
if random_state is None:
return np.random.mtrand._rand
elif isinstance(random_state, (numbers.Integral, np.integer)):
return np.random.RandomState(random_state)
if isinstance(random_state, np.random.RandomState):
return random_state
raise ValueError('Wrong random state. Expecting None, an int or a numpy '
'RandomState instance, got a '
'{}'.format(type(random_state)))
from ..utils import get_rng


def get_cv(cv):
Expand Down
17 changes: 14 additions & 3 deletions surprise/prediction_algorithms/co_clustering.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cimport numpy as np # noqa
import numpy as np

from .algo_base import AlgoBase
from ..utils import get_rng


class CoClustering(AlgoBase):
Expand Down Expand Up @@ -41,19 +42,28 @@ class CoClustering(AlgoBase):
n_cltr_i(int): Number of item clusters. Default is ``3``.
n_epochs(int): Number of iteration of the optimization loop. Default is
``20``.
random_state(int, RandomState instance from numpy, or ``None``):
Determines the RNG that will be used for initialization. If
int, ``random_state`` will be used as a seed for a new RNG. This is
useful to get the same initialization over multiple calls to
``fit()``. If RandomState instance, this same instance is used as
RNG. If ``None``, the current RNG from numpy is used. Default is
``None``.
verbose(bool): If True, the current epoch will be printed. Default is
``False``.
"""

def __init__(self, n_cltr_u=3, n_cltr_i=3, n_epochs=20, verbose=False):
def __init__(self, n_cltr_u=3, n_cltr_i=3, n_epochs=20, random_state=None,
verbose=False):

AlgoBase.__init__(self)

self.n_cltr_u = n_cltr_u
self.n_cltr_i = n_cltr_i
self.n_epochs = n_epochs
self.verbose=verbose
self.random_state = random_state

def fit(self, trainset):

Expand All @@ -80,8 +90,9 @@ class CoClustering(AlgoBase):
cdef double est

# Randomly assign users and items to intial clusters
cltr_u = np.random.randint(self.n_cltr_u, size=trainset.n_users)
cltr_i = np.random.randint(self.n_cltr_i, size=trainset.n_items)
rng = get_rng(self.random_state)
cltr_u = rng.randint(self.n_cltr_u, size=trainset.n_users)
cltr_i = rng.randint(self.n_cltr_i, size=trainset.n_items)

# Compute user and item means
user_mean = np.zeros(self.trainset.n_users, np.double)
Expand Down
64 changes: 47 additions & 17 deletions surprise/prediction_algorithms/matrix_factorization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from six.moves import range

from .algo_base import AlgoBase
from .predictions import PredictionImpossible
from ..utils import get_rng


class SVD(AlgoBase):
Expand Down Expand Up @@ -104,14 +105,21 @@ class SVD(AlgoBase):
over ``reg_all`` if set. Default is ``None``.
reg_qi: The regularization term for :math:`q_i`. Takes precedence
over ``reg_all`` if set. Default is ``None``.
random_state(int, RandomState instance from numpy, or ``None``):
Determines the RNG that will be used for initialization. If
int, ``random_state`` will be used as a seed for a new RNG. This is
useful to get the same initialization over multiple calls to
``fit()``. If RandomState instance, this same instance is used as
RNG. If ``None``, the current RNG from numpy is used. Default is
``None``.
verbose: If ``True``, prints the current epoch. Default is ``False``.
"""

def __init__(self, n_factors=100, n_epochs=20, biased=True, init_mean=0,
init_std_dev=.1, lr_all=.005,
reg_all=.02, lr_bu=None, lr_bi=None, lr_pu=None, lr_qi=None,
reg_bu=None, reg_bi=None, reg_pu=None, reg_qi=None,
verbose=False):
random_state=None, verbose=False):

self.n_factors = n_factors
self.n_epochs = n_epochs
Expand All @@ -126,6 +134,7 @@ class SVD(AlgoBase):
self.reg_bi = reg_bi if reg_bi is not None else reg_all
self.reg_pu = reg_pu if reg_pu is not None else reg_all
self.reg_qi = reg_qi if reg_qi is not None else reg_all
self.random_state = random_state
self.verbose = verbose

AlgoBase.__init__(self)
Expand Down Expand Up @@ -192,12 +201,14 @@ class SVD(AlgoBase):
cdef double reg_pu = self.reg_pu
cdef double reg_qi = self.reg_qi

rng = get_rng(self.random_state)

bu = np.zeros(trainset.n_users, np.double)
bi = np.zeros(trainset.n_items, np.double)
pu = np.random.normal(self.init_mean, self.init_std_dev,
(trainset.n_users, self.n_factors))
qi = np.random.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))
pu = rng.normal(self.init_mean, self.init_std_dev,
(trainset.n_users, self.n_factors))
qi = rng.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))

if not self.biased:
global_mean = 0
Expand Down Expand Up @@ -322,13 +333,20 @@ class SVDpp(AlgoBase):
over ``reg_all`` if set. Default is ``None``.
reg_yj: The regularization term for :math:`y_j`. Takes precedence
over ``reg_all`` if set. Default is ``None``.
random_state(int, RandomState instance from numpy, or ``None``):
Determines the RNG that will be used for initialization. If
int, ``random_state`` will be used as a seed for a new RNG. This is
useful to get the same initialization over multiple calls to
``fit()``. If RandomState instance, this same instance is used as
RNG. If ``None``, the current RNG from numpy is used. Default is
``None``.
verbose: If ``True``, prints the current epoch. Default is ``False``.
"""

def __init__(self, n_factors=20, n_epochs=20, init_mean=0, init_std_dev=.1,
lr_all=.007, reg_all=.02, lr_bu=None, lr_bi=None, lr_pu=None,
lr_qi=None, lr_yj=None, reg_bu=None, reg_bi=None, reg_pu=None,
reg_qi=None, reg_yj=None, verbose=False):
reg_qi=None, reg_yj=None, random_state=None, verbose=False):

self.n_factors = n_factors
self.n_epochs = n_epochs
Expand All @@ -344,6 +362,7 @@ class SVDpp(AlgoBase):
self.reg_pu = reg_pu if reg_pu is not None else reg_all
self.reg_qi = reg_qi if reg_qi is not None else reg_all
self.reg_yj = reg_yj if reg_yj is not None else reg_all
self.random_state = random_state
self.verbose = verbose

AlgoBase.__init__(self)
Expand Down Expand Up @@ -386,12 +405,14 @@ class SVDpp(AlgoBase):
bu = np.zeros(trainset.n_users, np.double)
bi = np.zeros(trainset.n_items, np.double)

pu = np.random.normal(self.init_mean, self.init_std_dev,
(trainset.n_users, self.n_factors))
qi = np.random.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))
yj = np.random.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))
rng = get_rng(self.random_state)

pu = rng.normal(self.init_mean, self.init_std_dev,
(trainset.n_users, self.n_factors))
qi = rng.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))
yj = rng.normal(self.init_mean, self.init_std_dev,
(trainset.n_items, self.n_factors))
u_impl_fdb = np.zeros(self.n_factors, np.double)

for current_epoch in range(self.n_epochs):
Expand Down Expand Up @@ -527,12 +548,19 @@ class NMF(AlgoBase):
``0``.
init_high: Higher bound for random initialization of factors. Default
is ``1``.
random_state(int, RandomState instance from numpy, or ``None``):
Determines the RNG that will be used for initialization. If
int, ``random_state`` will be used as a seed for a new RNG. This is
useful to get the same initialization over multiple calls to
``fit()``. If RandomState instance, this same instance is used as
RNG. If ``None``, the current RNG from numpy is used. Default is
``None``.
verbose: If ``True``, prints the current epoch. Default is ``False``.
"""

def __init__(self, n_factors=15, n_epochs=50, biased=False, reg_pu=.06,
reg_qi=.06, reg_bu=.02, reg_bi=.02, lr_bu=.005, lr_bi=.005,
init_low=0, init_high=1, verbose=False):
init_low=0, init_high=1, random_state=None, verbose=False):

self.n_factors = n_factors
self.n_epochs = n_epochs
Expand All @@ -545,6 +573,7 @@ class NMF(AlgoBase):
self.reg_bi = reg_bi
self.init_low = init_low
self.init_high = init_high
self.random_state = random_state
self.verbose = verbose

if self.init_low < 0:
Expand Down Expand Up @@ -584,10 +613,11 @@ class NMF(AlgoBase):
cdef double global_mean = self.trainset.global_mean

# Randomly initialize user and item factors
pu = np.random.uniform(self.init_low, self.init_high,
size=(trainset.n_users, self.n_factors))
qi = np.random.uniform(self.init_low, self.init_high,
size=(trainset.n_items, self.n_factors))
rng = get_rng(self.random_state)
pu = rng.uniform(self.init_low, self.init_high,
size=(trainset.n_users, self.n_factors))
qi = rng.uniform(self.init_low, self.init_high,
size=(trainset.n_items, self.n_factors))

bu = np.zeros(trainset.n_users, np.double)
bi = np.zeros(trainset.n_items, np.double)
Expand Down
25 changes: 25 additions & 0 deletions surprise/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
'''The utils module contains the get_rng function.'''

from __future__ import (absolute_import, division, print_function,
unicode_literals)
import numbers

import numpy as np


def get_rng(random_state):
'''Return a 'validated' RNG.
If random_state is None, use RandomState singleton from numpy. Else if
it's an integer, consider it's a seed and initialized an rng with that
seed. If it's already an rng, return it.
'''
if random_state is None:
return np.random.mtrand._rand
elif isinstance(random_state, (numbers.Integral, np.integer)):
return np.random.RandomState(random_state)
if isinstance(random_state, np.random.RandomState):
return random_state
raise ValueError('Wrong random state. Expecting None, an int or a numpy '
'RandomState instance, got a '
'{}'.format(type(random_state)))
26 changes: 13 additions & 13 deletions tests/test_NMF.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,64 +25,64 @@ def test_NMF_parameters():
"""Ensure that all parameters are taken into account."""

# The baseline against which to compare.
algo = NMF(n_factors=1, n_epochs=1)
algo = NMF(n_factors=1, n_epochs=1, random_state=1)
rmse_default = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']

# n_factors
algo = NMF(n_factors=2, n_epochs=1)
algo = NMF(n_factors=2, n_epochs=1, random_state=1)
rmse_factors = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_factors

# n_epochs
algo = NMF(n_factors=1, n_epochs=2)
algo = NMF(n_factors=1, n_epochs=2, random_state=1)
rmse_n_epochs = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_n_epochs

# biased
algo = NMF(n_factors=1, n_epochs=1, biased=True)
algo = NMF(n_factors=1, n_epochs=1, biased=True, random_state=1)
rmse_biased = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_biased

# reg_pu
algo = NMF(n_factors=1, n_epochs=1, reg_pu=1)
algo = NMF(n_factors=1, n_epochs=1, reg_pu=1, random_state=1)
rmse_reg_pu = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_reg_pu

# reg_qi
algo = NMF(n_factors=1, n_epochs=1, reg_qi=1)
algo = NMF(n_factors=1, n_epochs=1, reg_qi=1, random_state=1)
rmse_reg_qi = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_reg_qi

# reg_bu
algo = NMF(n_factors=1, n_epochs=1, reg_bu=1, biased=True)
algo = NMF(n_factors=1, n_epochs=1, reg_bu=1, biased=True, random_state=1)
rmse_reg_bu = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_reg_bu

# reg_bi
algo = NMF(n_factors=1, n_epochs=1, reg_bi=1, biased=True)
algo = NMF(n_factors=1, n_epochs=1, reg_bi=1, biased=True, random_state=1)
rmse_reg_bi = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_reg_bi

# lr_bu
algo = NMF(n_factors=1, n_epochs=1, lr_bu=1, biased=True)
algo = NMF(n_factors=1, n_epochs=1, lr_bu=1, biased=True, random_state=1)
rmse_lr_bu = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_lr_bu

# lr_bi
algo = NMF(n_factors=1, n_epochs=1, lr_bi=1, biased=True)
algo = NMF(n_factors=1, n_epochs=1, lr_bi=1, biased=True, random_state=1)
rmse_lr_bi = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_lr_bi

# init_low
algo = NMF(n_factors=1, n_epochs=1, init_low=.5)
algo = NMF(n_factors=1, n_epochs=1, init_low=.5, random_state=1)
rmse_init_low = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_init_low

# init_low
with pytest.raises(ValueError):
algo = NMF(n_factors=1, n_epochs=1, init_low=-1)
algo = NMF(n_factors=1, n_epochs=1, init_low=-1, random_state=1)

# init_high
algo = NMF(n_factors=1, n_epochs=1, init_high=.5)
algo = NMF(n_factors=1, n_epochs=1, init_high=.5, random_state=1)
rmse_init_high = cross_validate(algo, data, ['rmse'], pkf)['test_rmse']
assert rmse_default != rmse_init_high

0 comments on commit 9a6c673

Please sign in to comment.