-
Notifications
You must be signed in to change notification settings - Fork 54
Added intersection metric (#95) #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
blondered
merged 12 commits into
MTSWebServices:main
from
azatnv:feature/intersecton_metric
May 31, 2024
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
6e9ae03
Added intersection metric (#95)
azatnv f421bd6
Intersection metric behavior change (#95)
azatnv c9b217d
Added PR link to change log (#95)
azatnv b6e96b7
Updated metric behavior and docs (#95)
azatnv e374bd2
Added test for empty ref_reco (#95)
azatnv 974a750
Added intersection metric in calc_metrics (#95)
azatnv 7dc37b3
Changing variable name "ref_recos" to "ref_reco" (#95)
azatnv 7b3ef24
Added support intersection metric in cross_validate (#95)
azatnv 3f249d1
Changed behaviour cross_validate. Added test (#95)
azatnv e3ee177
Calculate intersection for all models (#95)
azatnv 8e56ae7
Change doc (#95)
azatnv ff8bdf6
Merge branch 'main' into feature/intersecton_metric
azatnv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| from typing import Dict, Hashable, Optional, Union | ||
|
|
||
| import attr | ||
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
| from rectools import Columns | ||
| from rectools.metrics.base import MetricAtK | ||
| from rectools.metrics.classification import Recall | ||
| from rectools.utils import select_by_type | ||
|
|
||
|
|
||
| @attr.s(auto_attribs=True) | ||
| class Intersection(MetricAtK): | ||
| """ | ||
| Metric to measure intersection in user-item pairs between recommendation lists. | ||
|
|
||
| The intersection@k equals the share of ``reco`` that is present in ``ref_reco``. | ||
|
|
||
| This corresponds to the following algorithm: | ||
| 1) filter ``reco`` by ``k`` | ||
| 2) filter ``ref_reco`` by ``ref_k`` | ||
| 3) calculate the proportion of items in ``reco`` that are also present in ``ref_reco`` | ||
| The second and third steps are equivalent to computing Recall@ref_k when: | ||
| - Interactions consists of ``reco`` without the `Columns.Rank` column. | ||
| - Recommendation table is ``ref_reco`` | ||
|
|
||
| Parameters | ||
| ---------- | ||
| k : int | ||
| Number of items in top of recommendations list that will be used to calculate metric. | ||
| ref_k : int, optional | ||
| Number of items in top of reference recommendations list that will be used to calculate metric. | ||
| If ``ref_k`` is None than ``ref_reco`` will be filtered with ``ref_k = k``. Default: None. | ||
| """ | ||
|
|
||
| ref_k: Optional[int] = attr.ib(default=None) | ||
|
|
||
| def calc(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> float: | ||
| """ | ||
| Calculate metric value. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| reco : pd.DataFrame | ||
| Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
| ref_reco : pd.DataFrame | ||
| Reference recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
|
|
||
| Returns | ||
| ------- | ||
| float | ||
| Value of metric (average between users). | ||
| """ | ||
| per_user = self.calc_per_user(reco, ref_reco) | ||
| return per_user.mean() | ||
|
|
||
| def calc_per_user(self, reco: pd.DataFrame, ref_reco: pd.DataFrame) -> pd.Series: | ||
| """ | ||
| Calculate metric values for all users. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| reco : pd.DataFrame | ||
| Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
| ref_reco : pd.DataFrame | ||
| Reference recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
|
|
||
| Returns | ||
| ------- | ||
| pd.Series: | ||
| Values of metric (index - user id, values - metric value for every user). | ||
| """ | ||
| self._check(reco) | ||
| assert set(ref_reco.columns) >= {Columns.User, Columns.Item, Columns.Rank} | ||
|
|
||
| if ref_reco.shape[0] == 0: | ||
| return pd.Series(index=pd.Series(name=Columns.User, dtype=int), dtype=np.float64) | ||
|
|
||
| if ref_reco is reco: | ||
| return pd.Series( | ||
| data=1, | ||
| index=pd.Series(data=reco[Columns.User].unique(), name=Columns.User, dtype=int), | ||
| dtype=np.float64, | ||
| ) | ||
|
|
||
| filtered_reco = reco[reco[Columns.Rank] <= self.k] | ||
|
|
||
| if self.ref_k is None: | ||
| self.ref_k = self.k | ||
| recall = Recall(k=self.ref_k) | ||
|
|
||
| return recall.calc_per_user(ref_reco, filtered_reco[Columns.UserItem]) | ||
|
|
||
|
|
||
| IntersectionMetric = Intersection | ||
|
|
||
|
|
||
| def calc_intersection_metrics( | ||
| metrics: Dict[str, IntersectionMetric], | ||
| reco: pd.DataFrame, | ||
| ref_reco: Union[pd.DataFrame, Dict[Hashable, pd.DataFrame]], | ||
| ) -> Dict[str, float]: | ||
| """ | ||
| Calculate intersection metrics. | ||
|
|
||
| Warning: It is not recommended to use this function directly. | ||
| Use `calc_metrics` instead. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| metrics : dict(str -> IntersectionMetric) | ||
| Dict of metric objects to calculate, | ||
| where key is metric name and value is metric object. | ||
| reco : pd.DataFrame | ||
| Recommendations table with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
| ref_reco : Union[pd.DataFrame, Dict[Hashable, pd.DataFrame]] | ||
| Reference recommendations table(s) with columns `Columns.User`, `Columns.Item`, `Columns.Rank`. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict(str->float) | ||
| Dictionary where keys are the same as keys in `metrics` | ||
| and values are metric calculation results. | ||
| """ | ||
| results = {} | ||
|
|
||
| intersection_metrics: Dict[str, Intersection] = select_by_type(metrics, Intersection) | ||
| if isinstance(ref_reco, pd.DataFrame): | ||
| for name, metric in intersection_metrics.items(): | ||
| results[name] = metric.calc(reco, ref_reco) | ||
| else: | ||
| for name, metric in intersection_metrics.items(): | ||
| for key, ref_r in ref_reco.items(): | ||
| results[f"{name}_{key}"] = metric.calc(reco, ref_r) | ||
|
|
||
| return results | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.