In [14]:
from models.sansa import SANSA

sansa_config = {
    "l2": 2.5,
    "target_density": 0.0005,
    "ainv_params": {
        "umr_scans": 4,
        "umr_finetune_steps": 10,
        "umr_loss_threshold": 1e-4,
    },
    "ldlt_method": "icf",
    "ldlt_params": {},
}
     
sansa = SANSA.from_config(sansa_config)


In [11]:
# Load msd data. Takes about 6 minutes on CSEL-CUDA-03.cselabs.umn.edu

from datasets.msd import MSD

msd_data_config = {
    "name": "msd",
    "rewrite": False,
}

msd_dataset = MSD.from_config(msd_data_config)

2024-04-27 22:19:12,660 : [1/3] DATASET : Loading processed dataset datasets/data/msd/dataset.parquet.


In [12]:
msd_split_config = {
    "n_val_users": 50000,
    "n_test_users": 50000,
    "seed": 42,
    "target_proportion": 0.2,
    "targets_newest": False,
}

(msd_train, msd_val, msd_test), msd_split_time = msd_dataset.create_splits(msd_split_config)

2024-04-27 22:19:38,833 : [1/3] DATASET : Dataframe lengths | train_df: 27728200, val_df: 2953824, test_df: 2951425
2024-04-27 22:21:18,509 : [1/3] DATASET : Splits information:
2024-04-27 22:21:18,513 : [1/3] DATASET : Train split info | n_users = 471355, n_items = 41140, n_ratings = 27728200, sparsity = 99.86%
2024-04-27 22:21:18,515 : [1/3] DATASET : Validation split info | n_users = 50000, n_items = 41140, n_ratings = 2382298, sparsity = 99.88%
2024-04-27 22:21:18,516 : [1/3] DATASET : Test split info | n_users = 50000, n_items = 41140, n_ratings = 2380405, sparsity = 99.88%
2024-04-27 22:21:18,517 : [1/3] DATASET : Execution of create_splits took at 117.312 seconds.


In [15]:
# Train Sansa
sansa.train(msd_train)

2024-04-27 22:31:41,512 : [2/3] TRAINING : Train user-item matrix info | n_users = 471355, n_items = 41140, n_ratings = 27728200, sparsity = 99.86%
2024-04-27 22:31:41,515 : [2/3] TRAINING : Item-item matrix info | shape = (41140,41140)
2024-04-27 22:31:41,516 : [2/3] TRAINING : Training SANSA with L2=2.5, target density=0.050000%, LDL^T method=icf, approx. inverse method=umr...
2024-04-27 22:31:41,518 : [2/3] TRAINING : Loading item-user matrix...
2024-04-27 22:31:43,804 : [2/3] TRAINING : Constructing weights:
2024-04-27 22:31:58,694 : [2/3] TRAINING : Constructing A...
2024-04-27 22:32:05,207 : [2/3] TRAINING : A info | nnz: 703833604, size: 8446.2 MB
2024-04-27 22:32:36,102 : [2/3] TRAINING : Computing incomplete LL^T decomposition...
2024-04-27 22:34:58,575 : [2/3] TRAINING : L info | nnz: 846061, size: 10.317 MB, density: 0.049989%
2024-04-27 22:34:58,577 : [2/3] TRAINING : Scaling columns and creating D (LL^T -> L'DL'^T)
2024-04-27 22:34:58,596 : [2/3] TRAINING : Execution of ld

In [16]:
import pandas as pd

# Evaluate on novelty
# Get all users
users = list(msd_test.user_encoder.classes_)
# Get rated items of users
users_rated = msd_test.get_rated_items(users)
targets = msd_test.get_target_items(users)
target_ids_dict = (
    targets.groupby("user_id", group_keys=True)["item_id"]
    .apply(list)
    .to_dict()
)
keys = list(target_ids_dict.keys())
users_to_arange = {user: i for i, user in enumerate(keys)}
pd.options.mode.chained_assignment = None  # suppress irrelevant warning
users_rated["user_id"] = users_rated["user_id"].map(users_to_arange)
pd.options.mode.chained_assignment = "warn"
top_maxk_ids, top_maxk_scores = sansa.recommend(users_rated, 20)

2024-04-27 22:40:02,655 : [3/3] EVALUATION : Execution of _matmat took at 0.186 seconds.
2024-04-27 22:40:06,510 : [3/3] EVALUATION : Execution of _matmat took at 3.850 seconds.
2024-04-27 22:40:20,405 : [3/3] EVALUATION : Execution of _predict took at 17.935 seconds.


In [18]:
# Create dictionary of item occurences for novelty metric
training_csr_matrix = msd_train.get_csr_matrix()
item_occurrences = training_csr_matrix.sum(axis=0)
item_ids = item_occurrences.nonzero()[1]
# Get the occurrences as a numpy array
item_occurrences = item_occurrences.A1
    
# Create a dictionary of item IDs and their occurrences
item_occurrences_dict = dict(zip(item_ids, item_occurrences))

In [19]:
import recmetrics

msd_novelty, msd_novelty_topn = recmetrics.novelty(top_maxk_ids, item_occurrences_dict, len(users), 20)

msd_novelty

4.644236306714888