diff --git a/README.rst b/README.rst
index f0e4d42..f35e9c2 100644
--- a/README.rst
+++ b/README.rst
@@ -192,7 +192,7 @@ Inheritance is powerful, as we can build new suites by re-using existing ones. H
 Any model can be tested, as no assumption is made on the model's structure, but only the
 availability of *predictions* and *ground truth*. Once again, while our example leverages a
 DataFrame-shaped dataset for these entities, you are free to build your own
-RecList instance with any shape you prefer, provided you implement the metrics accordingly (see the `examples/dummy.py` script for an example with different input types).
+RecList instance with any shape you prefer, provided you implement the metrics accordingly (see *dummy.py* for an example with different input types).

 Once you run a suite of tests, results are dumped automatically and versioned in a folder
 (local or on S3), structured as follows (name of the suite, name of the model, run timestamp):

@@ -220,10 +220,51 @@ based on DataFrames to make existing tests and metrics fully re-usable, but we d
 * flexible, Python interface to declare tests-as-functions, and annotate them with *display_type* for automated charts;

-* pre-built connectors with popular experiment trackers (e.g. Neptune, Comet), and an extensible interface to add your own (see the scripts in the `examples` folder for snippets on how to use third-party trackers);
+* pre-built connectors with popular experiment trackers (e.g. Neptune, Comet), and an extensible interface to add your own (see below);

 * reference implementations based on popular data challenges that used RecList: for an example of the "less wrong" latent space metric you can check the song2vec implementation `here `__.

+Using Third-Party Tracking Tools
+--------------------------------

+*RecList* supports streaming the results of your tests directly to your cloud platform of choice, both as metrics and charts.

+If you have the Neptune `Python client installed `__, you can use the
+Neptune logger simply by specifying it at init time, either passing *NEPTUNE_KEY* and *NEPTUNE_PROJECT_NAME* as kwargs or setting them as environment variables.

.. code-block:: python

    cdf = DFSessionRecList(
        dataset=df_events,
        model_name="myDataFrameRandomModel",
        predictions=df_predictions,
        y_test=df_dataset,
        logger=LOGGER.NEPTUNE,
        metadata_store=METADATA_STORE.LOCAL,
        similarity_model=my_sim_model
    )

    cdf(verbose=True)

+If you have the Comet `Python client installed `__, you can use the
+Comet logger simply by specifying it at init time, either passing *COMET_KEY*, *COMET_PROJECT_NAME*, and *COMET_WORKSPACE* as kwargs or setting them as environment variables.

.. code-block:: python

    cdf = DFSessionRecList(
        dataset=df_events,
        model_name="myDataFrameRandomModel",
        predictions=df_predictions,
        y_test=df_dataset,
        logger=LOGGER.COMET,
        metadata_store=METADATA_STORE.LOCAL,
        similarity_model=my_sim_model
    )

    cdf(verbose=True)

+If you wish to add a new platform, you can do so by implementing a new class that inherits from *RecLogger*.
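Below is a purely illustrative sketch of that extension point. The *RecLogger* base class is named in the paragraph above, but the import path and the overridden ``write`` hook in this snippet are assumptions rather than the documented interface, so treat it as a starting point and check the logger module in your *reclist* version before copying it.

.. code-block:: python

    # Hypothetical sketch: the import path and the write() hook are assumptions,
    # not the documented RecLogger API; adapt them to your reclist version.
    from reclist.logs.logger import RecLogger  # NOTE: import path is a guess

    class PrintLogger(RecLogger):
        """Toy connector that prints every metric instead of streaming it to a cloud tracker."""

        def write(self, label, value, **kwargs):
            # a real connector would call the third-party client here (e.g. Neptune or Comet)
            print(f"[{label}] -> {value}")

You would then wire such a class in at init time, in the same spirit as the Neptune and Comet snippets above.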
Acknowledgments
---------------

diff --git a/examples/evalrs_2023.py b/examples/evalrs_2023.py
index 1bbd13c..b0ad214 100644
--- a/examples/evalrs_2023.py
+++ b/examples/evalrs_2023.py
@@ -21,8 +21,10 @@
 import numpy as np
 import os

 from reclist.reclist import rec_test
-from reclist.reclist import RecList
+from reclist.reclist import RecList, CHART_TYPE
 from random import choice
+from gensim.models import KeyedVectors
+


 class DFSessionRecList(RecList):
@@ -53,42 +55,80 @@ def __init__(
         self.dataset = dataset
         self._y_preds = predictions
         self._y_test = kwargs.get("y_test", None)
+        self._user_metadata = kwargs.get("user_metadata", None)
+        if not isinstance(self._user_metadata, type(None)):
+            self._user_metadata = self._user_metadata.set_index("user_id")
         self.similarity_model = kwargs.get("similarity_model", None)
-        return

     @rec_test(test_type='HIT_RATE')
     def hit_rate_at_100(self):
-        hr = self.hit_rate_at_k(self._y_preds, self._y_test, k=100)
+        from reclist.metrics.standard_metrics import hit_rate_at_k
+        hr = hit_rate_at_k(self._y_preds, self._y_test, k=100)
         return hr

-    def hit_rate_at_k(self, y_pred: pd.DataFrame, y_test: pd.DataFrame, k: int):
-        """
-        N = number test cases
-        M = number ground truth per test case
-        """
-        hits = self.hits_at_k(y_pred, y_test, k)  # N x M x k
-        hits = hits.max(axis=1)  # N x k
-        return hits.max(axis=1).mean()  # 1
-
-    def hits_at_k(self, y_pred: pd.DataFrame, y_test: pd.DataFrame, k: int):
-        """
-        N = number test cases
-        M = number ground truth per test case
-        """
-        y_test_mask = y_test.values != -1  # N x M
-
-        y_pred_mask = y_pred.values[:, :k] != -1  # N x k
-
-        y_test = y_test.values[:, :, None]  # N x M x 1
-        y_pred = y_pred.values[:, None, :k]  # N x 1 x k
-
-        hits = y_test == y_pred  # N x M x k
-        hits = hits * y_test_mask[:, :, None]  # N x M x k
-        hits = hits * y_pred_mask[:, None, :]  # N x M x k
-
-        return hits
-
+    @rec_test(test_type='MRR')
+    def mrr_at_100(self):
+        from reclist.metrics.standard_metrics import mrr_at_k
+
+        return mrr_at_k(self._y_preds, self._y_test, k=100)
+
+    @rec_test(test_type='MRED_COUNTRY', display_type=CHART_TYPE.BARS)
+    def mred_country(self):
+        country_list = ["US", "RU", "DE", "UK", "PL", "BR", "FI", "NL", "ES", "SE", "UA", "CA", "FR", "NaN"]
+
+        user_countries = self._user_metadata.loc[self._y_test.index, ['country']].fillna('NaN')
+        valid_country_mask = user_countries['country'].isin(country_list)
+        y_pred_valid = self._y_preds[valid_country_mask]
+        y_test_valid = self._y_test[valid_country_mask]
+        user_countries = user_countries[valid_country_mask]
+
+        return self.miss_rate_equality_difference(y_pred_valid, y_test_valid, user_countries, 'country')
+
+    @rec_test(test_type='BEING_LESS_WRONG')
+    def being_less_wrong(self):
+        from reclist.metrics.standard_metrics import hits_at_k
+
+        hits = hits_at_k(self._y_preds, self._y_test, k=100).max(axis=2)
+        misses = (hits == False)
+        miss_gt_vectors = self.similarity_model[self._y_test.loc[misses, 'track_id'].values.reshape(-1)]
+        # we calculate the score w.r.t. the first prediction
+        miss_pred_vectors = self.similarity_model[self._y_preds.loc[misses, '0'].values.reshape(-1)]
+
+        return float(self.cosine_sim(miss_gt_vectors, miss_pred_vectors).mean())
+
+    def cosine_sim(self, u: np.array, v: np.array) -> np.array:
+        return np.sum(u * v, axis=-1) / (np.linalg.norm(u, axis=-1) * np.linalg.norm(v, axis=-1))
+
+    def miss_rate_at_k_slice(self,
+                             y_preds: pd.DataFrame,
+                             y_test: pd.DataFrame,
+                             slice_info: pd.DataFrame,
+                             slice_key: str):
+        from reclist.metrics.standard_metrics import misses_at_k
+        # get the misses (ground-truth items not retrieved in the top k)
+        m = misses_at_k(y_preds, y_test, k=100).min(axis=2)
+        # convert to dataframe
+        m = pd.DataFrame(m, columns=['mr'], index=y_test.index)
+        # grab slice info
+        m[slice_key] = slice_info[slice_key].values
+        # group by slice and get the per-slice miss rate
+        return m.groupby(slice_key)['mr'].agg('mean')
+
+    def miss_rate_equality_difference(self,
+                                      y_preds: pd.DataFrame,
+                                      y_test: pd.DataFrame,
+                                      slice_info: pd.DataFrame,
+                                      slice_key: str):
+        from reclist.metrics.standard_metrics import misses_at_k
+
+        mr_per_slice = self.miss_rate_at_k_slice(y_preds, y_test, slice_info, slice_key)
+        mr = misses_at_k(y_preds, y_test, k=100).min(axis=2).mean()
+        # take the negation so that higher values => better fairness
+        mred = -(mr_per_slice - mr).abs().mean()
+        res = mr_per_slice.to_dict()
+        return {'mred': mred, 'mr': mr, **res}
+

 class EvalRSSimpleModel:
     """
@@ -119,7 +159,8 @@ def predict(self, user_ids: pd.DataFrame) -> pd.DataFrame:
     df_tracks = pd.read_parquet('evalrs_dataset_KDD_2023/evalrs_tracks.parquet').set_index('track_id')
     df_users = pd.read_parquet('evalrs_dataset_KDD_2023/evalrs_users.parquet')

-    print(df_users['user_id'].head())
+    similarity_model = KeyedVectors.load('evalrs_dataset_KDD_2023/song2vec.wv')
+
     """
     Here we would normally train a model, but we just return random predictions.
     """
@@ -129,14 +170,15 @@ def predict(self, user_ids: pd.DataFrame) -> pd.DataFrame:
     all_tracks = df_tracks.index.values
     df_dataset = pd.DataFrame(
         {
+            'user_id': df_predictions.index.tolist(),
             'track_id': [choice(all_tracks) for _ in range(len(df_predictions))]
         }
-    )
+    ).set_index('user_id')
+
     """
     Here we use RecList to run the evaluation.
     """
-
     # initialize with everything
     cdf = DFSessionRecList(
         dataset=df_events,
@@ -146,10 +188,8 @@ def predict(self, user_ids: pd.DataFrame) -> pd.DataFrame:
         y_test=df_dataset,
         logger=LOGGER.LOCAL,
         metadata_store=METADATA_STORE.LOCAL,
-        # bucket=os.environ["S3_BUCKET"], # if METADATA_STORE.LOCAL you don't need this!
-        #NEPTUNE_KEY=os.environ["NEPTUNE_KEY"], # if LOGGER.NEPTUNE, make sure you have the env
-        #NEPTUNE_PROJECT_NAME=os.environ["NEPTUNE_PROJECT_NAME"] # if LOGGER.NEPTUNE, make sure you have the env
+        similarity_model=similarity_model,
+        user_metadata=df_users,
     )
-
     # run reclist
     cdf(verbose=True)

diff --git a/reclist/metrics/standard_metrics.py b/reclist/metrics/standard_metrics.py
index e6719f9..428bd33 100644
--- a/reclist/metrics/standard_metrics.py
+++ b/reclist/metrics/standard_metrics.py
@@ -122,10 +122,16 @@ def ranks_at_k(
            [2, 0, 1]])
    """
     hits = hits_at_k(y_pred, y_test, k)  # N x M x k
-    ranks = hits * np.arange(1, k + 1, 1)[None, None, :]  # N x M x k
-    ranks = ranks.max(axis=2)  # N x M
+    # TODO: hits_at_k can be modified to return df with last dim=k instead of preds shape
+    rank_overlap = min(k, hits.shape[-1])
+    ranks = hits * np.arange(1, rank_overlap + 1, 1)[None, None, :]  # N x M x k
+    # set to float
+    ranks = ranks.astype(float)
+    # set non-hits to infinity
+    ranks[ranks == 0] = np.inf
+    # keep the best (lowest) rank; if there is no hit, the rank stays infinite
+    ranks = ranks.min(axis=2)  # N x M

     return ranks

@@ -239,7 +245,8 @@ def rr_at_k(
    """
     ranks = ranks_at_k(y_pred, y_test, k).astype(np.float64)  # N x M

-    reciprocal_ranks = np.reciprocal(ranks, out=ranks, where=ranks > 0)  # N x M
+    reciprocal_ranks = np.reciprocal(ranks, out=ranks)  # N x M
+    # misses come back from ranks_at_k as np.inf, so their reciprocal is simply 0.0

     return reciprocal_ranks.max(axis=1)  # N

diff --git a/tests/test_reclist.py b/tests/test_reclist.py
index 73201dd..73d1234 100644
--- a/tests/test_reclist.py
+++ b/tests/test_reclist.py
@@ -81,6 +81,15 @@ def test_mrr():
        [[10, 12, 14, None, None, None],
         [22, 8, 64, 13, 1, 0]]
    )
+    df_f = pd.DataFrame(
+        [[10, 12, 14, None, None, None],
+         [22, 1, 64, 13, 1, 0]]
+    )
+
+    df_g = pd.DataFrame(
+        [[10, 12, 14, None, None, None],
+         [22, 17, 64, 13, 1, 0]]
+    )
    # df_f = pd.DataFrame(
    #     [[2, 3],
    #      [0, 1]]
@@ -95,4 +104,13 @@
    assert mrr_at_k(df_e, df_d, 2) == 1/4
    assert mrr_at_k(df_e, df_d, 3) == pytest.approx(5/12)
    assert mrr_at_k(df_e, df_d, 6) == pytest.approx(5/12)
+
+    # k larger than pred size
+    assert mrr_at_k(df_e, df_d, 20) == pytest.approx(5/12)
+
+    # repeated prediction that is a hit
+    assert mrr_at_k(df_f, df_d, 6) == pytest.approx(5/12)
+
+    assert mrr_at_k(df_g, df_d, 6) == pytest.approx(4/15)
+
\ No newline at end of file
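For readers wondering why the ``where=ranks > 0`` guard could be dropped from ``rr_at_k``: once ``ranks_at_k`` encodes misses as ``np.inf`` instead of ``0``, a plain reciprocal already maps them to ``0.0``. A small standalone illustration with toy numbers (not taken from the test file):

.. code-block:: python

    import numpy as np

    # toy ranks for two test cases: a hit at position 1 and a miss,
    # then a miss and a hit at position 4
    ranks = np.array([[1.0, np.inf],
                      [np.inf, 4.0]])

    reciprocal = np.reciprocal(ranks)   # 1/inf == 0.0, no divide-by-zero guard needed
    print(reciprocal)                   # -> [[1.   0.  ], [0.   0.25]]
    print(reciprocal.max(axis=1))       # best reciprocal rank per test case -> [1.   0.25]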
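Similarly, the ``MRED_COUNTRY`` test added to ``examples/evalrs_2023.py`` reduces to a per-slice group-by once the per-user misses are known. Here is a standalone toy version in plain pandas, with made-up data and no ``reclist`` imports, mirroring the arithmetic in ``miss_rate_equality_difference``:

.. code-block:: python

    import pandas as pd

    # one row per test user: miss = 1 if the ground-truth track is absent from the top-k predictions
    per_user = pd.DataFrame({
        "miss":    [1, 0, 1, 1, 0, 1],
        "country": ["US", "US", "DE", "DE", "BR", "BR"],
    })

    mr = per_user["miss"].mean()                                # global miss rate (~0.67 here)
    mr_per_slice = per_user.groupby("country")["miss"].mean()   # per-country miss rates
    mred = -(mr_per_slice - mr).abs().mean()                    # negated, so values closer to 0 mean fairer

    print({"mred": mred, "mr": mr, **mr_per_slice.to_dict()})   # mred ~= -0.22 for these toy numbers

This mirrors the dictionary the rec test returns: ``mred``, the global ``mr``, plus one entry per country slice.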