diff --git a/recbole_cdr/data/dataset.py b/recbole_cdr/data/dataset.py index 31b03e2..31058bc 100644 --- a/recbole_cdr/data/dataset.py +++ b/recbole_cdr/data/dataset.py @@ -22,6 +22,7 @@ from recbole.data.dataset import Dataset from recbole.utils import FeatureSource, FeatureType, set_color +from recbole_cdr.utils import get_keys_from_chainmap_by_order class CrossDomainSingleDataset(Dataset): @@ -112,7 +113,7 @@ def _remap_fields(self, field_names, map_dict): map_dict (dict): The dict whose keys are the original ids and values are the new ids. """ for field_name in field_names: - self.field2id_token[field_name] = list(map_dict.keys()) + self.field2id_token[field_name] = get_keys_from_chainmap_by_order(map_dict) self.field2token_id[field_name] = map_dict if field_name in self.inter_feat.columns: self.inter_feat[field_name] = self.inter_feat[field_name].map(lambda x: map_dict.get(x, x))