Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ Markdown = "~3.2"

torch = {version = "^1.6", optional = true}
pytorch-lightning = {version = "^1.6", optional = true}
recommenders = {version = ">=1.1.0", optional = true}

[tool.poetry.extras]
recders = ["recommenders"]
nn = ["torch", "pytorch-lightning"]
all = ["torch", "pytorch-lightning"]
all = ["torch", "pytorch-lightning", "recommenders"]

[tool.poetry.dev-dependencies]
black = "22.3.0"
Expand Down
14 changes: 14 additions & 0 deletions rectools/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,21 @@ def __new__(cls, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
f"run `pip install rectools[nn]` to install extra requirements before accessing {cls.__name__} "
f"(see `extras/requirements-nn.txt)"
)

class RecommendersModelUnavailable:
"""Dummy class the instance of which is returned in case a model provided lacks any libraries required"""

def __new__(cls, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
"""Raise ImportError when an attempt to instantiate an unavailable model is made"""
raise ImportError(
f"Cannot initialize {cls.__name__}: "
f"run `pip install rectools[recders]` to install extra requirements before accessing {cls.__name__} "
f"(see `extras/requirements-nn.txt)"
)


class DSSMModel(NNModelUnavailable):
"""Dummy class the instance of which is returned in case DSSMModel lacks any libraries required"""

class SarWrapper(RecommendersModelUnavailable):
"""Dummy class the instance of which is returned in case SarWrapper lacks any libraries required"""
5 changes: 5 additions & 0 deletions rectools/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
`models.PopularInCategoryModel`
`models.PureSVDModel`
`models.RandomModel`
`models.SarWrapper`
"""

from .implicit_als import ImplicitALSWrapperModel
Expand All @@ -42,6 +43,10 @@
from .popular_in_category import PopularInCategoryModel
from .pure_svd import PureSVDModel
from .random import RandomModel
try:
from .sar import SarWrapper
except ImportError:
from ..compat import SarWrapper

try:
from .dssm import DSSMModel
Expand Down
102 changes: 102 additions & 0 deletions rectools/models/sar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp
import pandas as pd
import numpy as np
from recommenders.models.sar import SAR
from rectools import ExternalIds



from rectools.dataset import Dataset
from rectools.exceptions import NotFittedError
from rectools import Columns, ExternalIds, InternalIds
from rectools.models.base import ModelBase, Scores

class SarWrapper(ModelBase) :
"""
Simple Algorithm for Recommendations (SAR) implementation

SAR is a fast scalable adaptive algorithm for personalized recommendations based on user transaction history
and items description. The core idea behind SAR is to recommend items like those that a user already has
demonstrated an affinity to. It does this by 1) estimating the affinity of users for items, 2) estimating
similarity across items, and then 3) combining the estimates to generate a set of recommendations for a given user.

!!! Can't recomend items from other dataset except from original (sorted_item_ids_to_recommend does't do anything) !!!

Parameters
----------
time_decay_coefficient : float, default 30
Number of days till ratings are decayed by 1/2
time_now int | None, default None
Current time for time decay calculation
timedecay_formula : bool, default False
Flag to apply time decay
"""
def __init__(self, time_decay_coefficient=30, time_now=None, timedecay_formula=False) :
self.is_fitted = False
self._model = SAR(
col_user=Columns.User,
col_item=Columns.Item,
col_rating=Columns.Weight,
col_timestamp=Columns.Datetime,
col_prediction=Columns.Score,
time_decay_coefficient=time_decay_coefficient,
time_now=time_now,
timedecay_formula=timedecay_formula,
normalize=True
)
def _fit(self, dataset : Dataset) -> None :
self.is_fitted = True
self._model.fit(dataset.interactions.df)

def _recommend_u2i(
self,
user_ids: np.ndarray,
dataset: Dataset,
k: int,
filter_viewed: bool,
sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
df = dataset.interactions.df

result = self._model.recommend_k_items(df[df[Columns.User].isin(user_ids.tolist())], top_k=k, remove_seen=filter_viewed)

return [result[Columns.User].to_numpy(), result[Columns.Item].to_numpy(), result[Columns.Score].to_numpy()]

def _recommend_i2i(
self,
target_ids: np.ndarray,
dataset: Dataset,
k: int,
sorted_item_ids_to_recommend: tp.Optional[np.ndarray],
) -> tp.Tuple[InternalIds, InternalIds, Scores]:
res_items = pd.DataFrame(
{
Columns.TargetItem,
Columns.Item,
Columns.Score
}
)

df = dataset.interactions.df

for i in target_ids :
tmp = self._model.get_item_based_topk(pd.DataFrame(data={Columns.Item : [i]}), top_k=k)
tmp.drop(Columns.User)
tmp[Columns.TargetItem] = i
res_items.append(tmp)

return [res_items[Columns.TargetItem].to_numpy(), res_items[Columns.Item].to_numpy(), res_items[Columns.Score].to_numpy()]
91 changes: 91 additions & 0 deletions tests/models/test_sar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import pytest

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models.sar import SarWrapper
from .data import DATASET

class TestSarWrapper :
@pytest.fixture
def dataset(self) -> Dataset:
return DATASET

@pytest.mark.parametrize(
"filter_viewed,expected",
(
(
False,
pd.DataFrame(
{
Columns.User: [10, 10, 20, 20],
Columns.Item: [11, 17, 11, 17]
}
)
),
(
True,
pd.DataFrame(
{
Columns.User: [10, 10, 20, 20],
Columns.Item: [13, 17, 15, 17]
}
)
)
)
)
def test_recomend(
self,
dataset: Dataset,
filter_viewed: bool,
expected: pd.DataFrame) -> None:
sar = SarWrapper()
sar.fit(dataset)
actual = sar.recommend(
users=np.array([10, 20]),
dataset=dataset,
k=2,
filter_viewed=filter_viewed,
add_rank_col=False
)
pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected)

@pytest.mark.parametrize(
"expected",
(
(
pd.DataFrame(
{
Columns.TargetItem: [11, 11, 12, 12],
Columns.Item: [12, 14, 11, 14]
}
)
)
)
)
def test_i2i(
self, dataset: Dataset, expected: pd.DataFrame) -> None:
model = SarWrapper().fit(dataset)
actual = model.recommend_to_items(
target_items=np.array([11, 12]),
dataset=dataset,
k=2,
filter_itself=False,
add_rank_col=False
)
pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected)