diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5a7236e..2ab3552 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -36,7 +36,7 @@ jobs: python -m pip install --upgrade pip pip install pytest pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - pip install recbole + pip install recbole==1.0.1 conda list pip install setuptools==59.5.0 pip install protobuf~=3.19.0 diff --git a/recbole_cdr/config/configurator.py b/recbole_cdr/config/configurator.py index efe4197..47bd8bf 100644 --- a/recbole_cdr/config/configurator.py +++ b/recbole_cdr/config/configurator.py @@ -56,6 +56,7 @@ def __init__(self, model=None, config_file_list=None, config_dict=None): config_file_list (list of str): the external config file, it allows multiple config files, default is None. config_dict (dict): the external parameter dictionaries, default is None. """ + self.compatibility_settings() self._init_parameters_category() self.parameters['Dataset'] += ['source_domain', 'target_domain'] self.yaml_loader = self._build_yaml_loader() @@ -283,3 +284,14 @@ def update(self, other_config): for key in other_config: new_config_obj.final_config_dict[key] = other_config[key] return new_config_obj + + def compatibility_settings(self): + import numpy as np + np.bool = np.bool_ + np.int = np.int_ + np.float = np.float_ + np.complex = np.complex_ + np.object = np.object_ + np.str = np.str_ + np.long = np.int_ + np.unicode = np.unicode_ diff --git a/recbole_cdr/data/dataset.py b/recbole_cdr/data/dataset.py index 31b03e2..31058bc 100644 --- a/recbole_cdr/data/dataset.py +++ b/recbole_cdr/data/dataset.py @@ -22,6 +22,7 @@ from recbole.data.dataset import Dataset from recbole.utils import FeatureSource, FeatureType, set_color +from recbole_cdr.utils import get_keys_from_chainmap_by_order class CrossDomainSingleDataset(Dataset): @@ -112,7 +113,7 @@ def _remap_fields(self, field_names, map_dict): map_dict (dict): The dict whose keys are the original ids and values are the new ids. """ for field_name in field_names: - self.field2id_token[field_name] = list(map_dict.keys()) + self.field2id_token[field_name] = get_keys_from_chainmap_by_order(map_dict) self.field2token_id[field_name] = map_dict if field_name in self.inter_feat.columns: self.inter_feat[field_name] = self.inter_feat[field_name].map(lambda x: map_dict.get(x, x)) diff --git a/recbole_cdr/model/cross_domain_recommender/bitgcf.py b/recbole_cdr/model/cross_domain_recommender/bitgcf.py index 22e98dd..43aba68 100644 --- a/recbole_cdr/model/cross_domain_recommender/bitgcf.py +++ b/recbole_cdr/model/cross_domain_recommender/bitgcf.py @@ -145,12 +145,12 @@ def transfer_layer(self, source_all_embeddings, target_all_embeddings): source_user_laplace = self.source_user_degree_count target_user_laplace = self.target_user_degree_count - user_laplace = source_user_laplace + target_user_laplace + user_laplace = source_user_laplace + target_user_laplace + 1e-7 source_user_embeddings_lap = (source_user_laplace * source_user_embeddings + target_user_laplace * target_user_embeddings) / user_laplace target_user_embeddings_lap = source_user_embeddings_lap source_item_laplace = self.source_item_degree_count target_item_laplace = self.target_item_degree_count - item_laplace = source_item_laplace + target_item_laplace + item_laplace = source_item_laplace + target_item_laplace + 1e-7 source_item_embeddings_lap = (source_item_laplace * source_item_embeddings + target_item_laplace * target_item_embeddings) / item_laplace target_item_embeddings_lap = source_item_embeddings_lap diff --git a/recbole_cdr/utils/__init__.py b/recbole_cdr/utils/__init__.py index 3043914..5064dcf 100644 --- a/recbole_cdr/utils/__init__.py +++ b/recbole_cdr/utils/__init__.py @@ -1,6 +1,6 @@ -from recbole_cdr.utils.utils import get_model, get_trainer +from recbole_cdr.utils.utils import get_model, get_trainer, get_keys_from_chainmap_by_order from recbole_cdr.utils.enum_type import * __all__ = [ - 'get_model', 'get_trainer', 'Enum', 'ModelType', 'CrossDomainDataLoaderState', 'train_mode2state' + 'get_model', 'get_trainer', 'Enum', 'ModelType', 'CrossDomainDataLoaderState', 'train_mode2state', 'get_keys_from_chainmap_by_order' ] diff --git a/recbole_cdr/utils/utils.py b/recbole_cdr/utils/utils.py index 971e684..d41b418 100644 --- a/recbole_cdr/utils/utils.py +++ b/recbole_cdr/utils/utils.py @@ -57,3 +57,10 @@ def get_trainer(model_type, model_name): return getattr(importlib.import_module('recbole_cdr.trainer'), 'CrossDomainTrainer') else: return getattr(importlib.import_module('recbole.trainer'), 'Trainer') + + +def get_keys_from_chainmap_by_order(map_dict): + merged_dict = dict() + for dict_item in map_dict.maps: + merged_dict.update(dict_item) + return list(merged_dict.keys())