Skip to content

Commit

Permalink
EASEᴿ Implementation (#477)
Browse files Browse the repository at this point in the history
* EASEᴿ Implementation

* sorting corrected

* corrections

* Update __init__.py

* Update README.md

Co-authored-by: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
  • Loading branch information
yilmazerhakan and tqtg committed Jun 17, 2022
1 parent b300716 commit 3f7a04a
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
| | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py)
| | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER)
| 2020 | [Adversarial Training Towards Robust Multimedia Recommender System (AMR)](cornac/models/amr), [paper](https://ieeexplore.ieee.org/document/8618394) | [requirements.txt](cornac/models/amr/requirements.txt) | [amr_clothing.py](examples/amr_clothing.py)
| 2019 | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py)
| 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_exp.py](examples/c2pf_example.py)
| | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_exp.py](examples/mter_example.py)
| | [Neural Attention Rating Regression with Review-level Explanations (NARRE)](cornac/models/narre), [paper](http://www.thuir.cn/group/~YQLiu/publications/WWW2018_CC.pdf) | [requirements.txt](cornac/models/narre/requirements.txt) | [narre_example.py](examples/narre_example.py)
Expand Down Expand Up @@ -153,6 +154,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
| | [User K-Nearest-Neighbors (UserKNN)](cornac/models/knn), [paper](https://arxiv.org/pdf/1301.7363.pdf) | N/A | [knn_movielens.py](examples/knn_movielens.py)
| | [Weighted Matrix Factorization (WMF)](cornac/models/wmf), [paper](http://yifanhu.net/PUB/cf.pdf) | [requirements.txt](cornac/models/wmf/requirements.txt) | [wmf_exp.py](examples/wmf_example.py)


## Support

Your contributions at any level of the library are welcome. If you intend to contribute, please:
Expand Down
1 change: 1 addition & 0 deletions cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .ctr import CTR
from .cvae import CVAE
from .cvaecf import CVAECF
from .ease import EASE
from .efm import EFM
from .global_avg import GlobalAvg
from .hft import HFT
Expand Down
1 change: 1 addition & 0 deletions cornac/models/ease/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .recom_ease import EASE
135 changes: 135 additions & 0 deletions cornac/models/ease/recom_ease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np

from cornac.models.recommender import Recommender
from cornac.exception import ScoreException

class EASE(Recommender):
"""Embarrassingly Shallow Autoencoders for Sparse Data.
Parameters
----------
name: string, optional, default: 'EASEᴿ'
The name of the recommender model.
lamb: float, optional, default: 500
L2-norm regularization-parameter λ ∈ R+.
posB: boolean, optional, default: False
Remove Negative Weights
trainable: boolean, optional, default: True
When False, the model is not trained and Cornac assumes that the model is already \
trained.
verbose: boolean, optional, default: False
When True, some running logs are displayed.
seed: int, optional, default: None
Random seed for parameters initialization.
References
----------
* Steck, H. (2019, May). "Embarrassingly shallow autoencoders for sparse data." \
In The World Wide Web Conference (pp. 3251-3257).
"""

def __init__(
self,
name="EASEᴿ",
lamb=500,
posB=True,
trainable=True,
verbose=True,
seed=None,
B=None,
U=None,
):
Recommender.__init__(self, name=name, trainable=trainable, verbose=verbose)
self.lamb = lamb
self.posB = posB
self.verbose = verbose
self.seed = seed
self.B = B
self.U = U

def fit(self, train_set, val_set=None):
"""Fit the model to observations.
Parameters
----------
train_set: :obj:`cornac.data.Dataset`, required
User-Item preference data as well as additional modalities.
val_set: :obj:`cornac.data.Dataset`, optional, default: None
User-Item preference data for model selection purposes (e.g., early stopping).
Returns
-------
self : object
"""
Recommender.fit(self, train_set, val_set)

# A rating matrix
self.U = self.train_set.matrix

# Gram matrix is X^t X, compute dot product
G = self.U.T.dot(self.U).toarray()

diag_indices = np.diag_indices(G.shape[0])

G[diag_indices] = G.diagonal() + self.lamb

P = np.linalg.inv(G)

B = P / (-np.diag(P))

B[diag_indices] = 0.0

# if self.posB remove -ve values
if self.posB:
B[B<0]=0

# save B for predictions
self.B=B

return self


def score(self, user_idx, item_idx=None):
"""Predict the scores/ratings of a user for an item.
Parameters
----------
user_idx: int, required
The index of the user for whom to perform score prediction.
item_idx: int, optional, default: None
The index of the item for which to perform score prediction.
If None, scores for all known items will be returned.
Returns
-------
res : A scalar or a Numpy array
Relative scores that the user gives to the item or to all known items
"""
if item_idx is None:
if self.train_set.is_unk_user(user_idx):
raise ScoreException(
"Can't make score prediction for (user_id=%d)" % user_idx
)

known_item_scores = self.U[user_idx, :].dot(self.B)
return known_item_scores
else:
if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(
item_idx
):
raise ScoreException(
"Can't make score prediction for (user_id=%d, item_id=%d)"
% (user_idx, item_idx)
)

user_pred = self.B[item_idx, :].dot(self.U[user_idx, :])

return user_pred
7 changes: 6 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ Adversarial Training Towards Robust Multimedia Recommender System (AMR)
.. automodule:: cornac.models.amr.recom_amr
:members:

Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)
--------------------------------------------------
.. automodule:: cornac.models.ease.recom_ease
:members:

Collaborative Context Poisson Factorization (C2PF)
----------------------------------------------------
.. automodule:: cornac.models.c2pf.recom_c2pf
Expand Down Expand Up @@ -231,4 +236,4 @@ User K-Nearest-Neighbors (UserKNN)
Weighted Matrix Factorization (WMF)
--------------------------------------------------
.. automodule:: cornac.models.wmf.recom_wmf
:members:
:members:
2 changes: 2 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@

[bpr_netflix.py](bpr_netflix.py) - Example to run Bayesian Personalized Ranking (BPR) with Netflix dataset.

[ease_movielens.py](ease_movielens.py) - Embarrassingly Shallow Autoencoders (EASEᴿ) with MovieLens 1M dataset.

[fm_example.py](fm_example.py) - Example to run Factorization Machines (FM) with MovieLens 100K dataset.

[hpf_movielens.py](hpf_movielens.py) - (Hierarchical) Poisson Factorization vs BPR on MovieLens data.
Expand Down
61 changes: 61 additions & 0 deletions examples/ease_movielens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2018 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Example (EASEᴿ) Embarrassingly Shallow Autoencoders for Sparse Data on MovieLens data"""

import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit


# Load user-item feedback
data = movielens.load_feedback(variant="1M")

# Instantiate an evaluation method to split data into train and test sets.
ratio_split = RatioSplit(
data=data,
test_size=0.2,
exclude_unknowns=True,
verbose=True,
seed=123,
rating_threshold=0.8,
)

ease_original = cornac.models.EASE(
lamb=500,
name="EASEᴿ (B>0)",
posB=True
)

ease_all = cornac.models.EASE(
lamb=500,
name="EASEᴿ (B>-∞)",
posB=False
)


# Instantiate evaluation measures
rec_20 = cornac.metrics.Recall(k=20)
rec_50 = cornac.metrics.Recall(k=50)
ndcg_100 = cornac.metrics.NDCG(k=100)


# Put everything together into an experiment and run it
cornac.Experiment(
eval_method=ratio_split,
models=[ease_original, ease_all],
metrics=[rec_20, rec_50, ndcg_100],
user_based=True, #If `False`, results will be averaged over the number of ratings.
save_dir=None
).run()

0 comments on commit 3f7a04a

Please sign in to comment.