Skip to content

Commit

Permalink
Rename symbols to ensure consistency (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Dec 12, 2023
1 parent ea8352f commit e54190f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
22 changes: 15 additions & 7 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ class BasketDataset(Dataset):
uir_tuple: tuple, required
Tuple of 3 numpy arrays (user_indices, item_indices, rating_values).
basket_ids: numpy.array, required
basket_indices: numpy.array, required
Array of basket indices corresponding to observation in `uir_tuple`.
timestamps: numpy.array, optional, default: None
Expand Down Expand Up @@ -677,7 +677,7 @@ def __init__(
bid_map,
iid_map,
uir_tuple,
basket_ids=None,
basket_indices=None,
timestamps=None,
extra_data=None,
seed=None,
Expand All @@ -693,23 +693,31 @@ def __init__(
)
self.num_baskets = num_baskets
self.bid_map = bid_map
self.basket_ids = basket_ids
self.basket_indices = basket_indices
self.extra_data = extra_data
basket_sizes = list(Counter(basket_ids).values())
basket_sizes = list(Counter(basket_indices).values())
self.max_basket_size = np.max(basket_sizes)
self.min_basket_size = np.min(basket_sizes)
self.avg_basket_size = np.mean(basket_sizes)

self.__baskets = None
self.__basket_ids = None
self.__user_basket_data = None
self.__chrono_user_basket_data = None

@property
def basket_ids(self):
"""Return the list of raw basket ids"""
if self.__basket_ids is None:
self.__basket_ids = list(self.bid_map.keys())
return self.__basket_ids

@property
def baskets(self):
"""A dictionary to store indices where basket ID appears in the data."""
if self.__baskets is None:
self.__baskets = defaultdict(list)
for idx, bid in enumerate(self.basket_ids):
for idx, bid in enumerate(self.basket_indices):
self.__baskets[bid].append(idx)
return self.__baskets

Expand Down Expand Up @@ -836,7 +844,7 @@ def build(
np.ones(len(u_indices), dtype="float"),
)

basket_ids = np.asarray(b_indices, dtype="int")
basket_indices = np.asarray(b_indices, dtype="int")

timestamps = (
np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int")
Expand All @@ -854,7 +862,7 @@ def build(
bid_map=global_bid_map,
iid_map=global_iid_map,
uir_tuple=uir_tuple,
basket_ids=basket_ids,
basket_indices=basket_indices,
timestamps=timestamps,
extra_data=extra_data,
seed=seed,
Expand Down
4 changes: 2 additions & 2 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ def get_gt_items(train_set, test_set, test_pos_items, exclude_unknowns):
user_idx,
item_indices,
history_baskets=history_baskets,
history_basket_ids=bids[:-1],
history_bids=bids[:-1],
uir_tuple=test_set.uir_tuple,
baskets=test_set.baskets,
basket_ids=test_set.basket_ids,
basket_indices=test_set.basket_indices,
extra_data=test_set.extra_data,
)

Expand Down
10 changes: 4 additions & 6 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,13 @@ def score(self, user_idx, history_baskets, **kwargs):

if self.use_personalized_popularity:
if self.use_quantity:
history_basket_bids = kwargs.get("history_basket_ids")
history_bids = kwargs.get("history_bids")
baskets = kwargs.get("baskets")
p_item_freq = Counter()
(_, item_ids, _) = kwargs.get("uir_tuple")
extra_data = kwargs.get("extra_data")
for bid in history_basket_bids:
ids = baskets[bid]
for idx in ids:
p_item_freq[item_ids[idx]] += extra_data[idx].get("quantity", 0)
for bid, iids in zip(history_bids, history_baskets):
for idx, iid in zip(baskets[bid], iids):
p_item_freq[iid] += extra_data[idx].get("quantity", 0)
else:
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])
for iid, cnt in p_item_freq.most_common():
Expand Down

0 comments on commit e54190f

Please sign in to comment.