Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SLIMElastic #621

Merged
merged 4 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recbole/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
from recbole.model.general_recommender.neumf import NeuMF
from recbole.model.general_recommender.ngcf import NGCF
from recbole.model.general_recommender.pop import Pop
from recbole.model.general_recommender.slimelastic import SLIMElastic
from recbole.model.general_recommender.spectralcf import SpectralCF
110 changes: 110 additions & 0 deletions recbole/model/general_recommender/slimelastic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
r"""
SLIMElastic
################################################
Reference:
10.1109/ICDM.2011.134
https://www.slideshare.net/MarkLevy/efficient-slides

Reference code:
https://github.com/KarypisLab/SLIM
https://github.com/MaurizioFD/RecSys2019_DeepLearning_Evaluation/blob/master/SLIM_ElasticNet/SLIMElasticNetRecommender.py
"""


from recbole.utils.enum_type import ModelType
import numpy as np
import scipy.sparse as sp
import torch
import warnings
from sklearn.linear_model import ElasticNet
from sklearn.exceptions import ConvergenceWarning

from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender


# https://github.com/RUCAIBox/RecBole/issues/622
def add_noise(t, mag=1e-5):
return t + mag * torch.rand(t.shape)


class SLIMElastic(GeneralRecommender):
input_type = InputType.POINTWISE
type = ModelType.TRADITIONAL

def __init__(self, config, dataset):
super().__init__(config, dataset)

# need at least one param
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

X = dataset.inter_matrix(
form='csr').astype(np.float32)

X = X.tolil()
self.interaction_matrix = X

hide_item = config['hide_item']
alpha = config['alpha']
l1_ratio = config['l1_ratio']
positive_only = config['positive_only']

model = ElasticNet(alpha=alpha, l1_ratio=l1_ratio,
positive=positive_only,
fit_intercept=False,
copy_X=False,
precompute=True,
selection='random',
max_iter=100,
tol=1e-4)

item_coeffs = []

# ignore ConvergenceWarnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)

for j in range(X.shape[1]):
# target column
r = X[:, j]

if hide_item:
# set item column to 0
X[:, j] = 0

# fit the model
model.fit(X, r.todense().getA1())

# store the coefficients
coeffs = model.sparse_coef_

item_coeffs.append(coeffs)

if hide_item:
# reattach column if removed
X[:, j] = r

self.item_similarity = sp.vstack(item_coeffs).T

def forward(self):
pass

def calculate_loss(self, interaction):
return torch.nn.Parameter(torch.zeros(1))

def predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()
item = interaction[self.ITEM_ID].cpu().numpy()

r = torch.from_numpy((self.interaction_matrix[user, :].multiply(
self.item_similarity[:, item].T)).sum(axis=1).getA1())

return add_noise(r)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID].cpu().numpy()

r = self.interaction_matrix[user, :] @ self.item_similarity
r = torch.from_numpy(r.todense().getA1())

return add_noise(r)
4 changes: 4 additions & 0 deletions recbole/properties/model/SLIMElastic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
alpha: 0.2
l1_ratio: 0.02
positive_only: False
hide_item: True
10 changes: 8 additions & 2 deletions run_test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@
'model': 'LINE',
'dataset': 'ml-100k',
},
'Test SLIMElastic': {
'model': 'SLIMElastic',
'dataset': 'ml-100k',
},

# Context-aware Recommendation
'Test FM': {
Expand Down Expand Up @@ -339,13 +343,15 @@ def run_test_examples():
for idx, example in enumerate(test_examples.keys()):
if example in closed_examples:
continue
print('\n\n Begin to run %d / %d example: %s \n\n' % (idx + 1, n_examples, example))
print('\n\n Begin to run %d / %d example: %s \n\n' %
(idx + 1, n_examples, example))
try:
config_dict = test_examples[example]
if 'epochs' not in config_dict:
config_dict['epochs'] = 1
run_recbole(config_dict=config_dict, saved=False)
print('\n\n Running %d / %d example successfully: %s \n\n' % (idx + 1, n_examples, example))
print('\n\n Running %d / %d example successfully: %s \n\n' %
(idx + 1, n_examples, example))
success_examples.append(example)
except Exception:
print(traceback.format_exc())
Expand Down
13 changes: 11 additions & 2 deletions tests/model/test_model_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# UPDATE
# @Time : 2020/11/17
# @Author : Xingyu Pan
# @email : panxy@ruc.edu.cn
# @email : panxy@ruc.edu.cn

import os
import unittest
Expand All @@ -16,6 +16,7 @@
current_path = os.path.dirname(os.path.realpath(__file__))
config_file_list = [os.path.join(current_path, 'test_model.yaml')]


class TestGeneralRecommender(unittest.TestCase):

def test_pop(self):
Expand Down Expand Up @@ -116,6 +117,13 @@ def test_line(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)

def test_slimelastic(self):
config_dict = {
'model': 'SLIMElastic',
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)


class TestContextRecommender(unittest.TestCase):
# todo: more complex context information should be test, such as criteo dataset
Expand Down Expand Up @@ -189,7 +197,7 @@ def test_widedeep(self):
}
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)

# def test_dcn(self):
# config_dict = {
# 'model': 'DCN',
Expand Down Expand Up @@ -760,5 +768,6 @@ def test_fdsa(self):
objective_function(config_dict=config_dict,
config_file_list=config_file_list, saved=False)


if __name__ == '__main__':
unittest.main()