Skip to content

Commit

Permalink
Merge pull request #56 from Wicknight/main
Browse files Browse the repository at this point in the history
FIX: CI test, BiTGCF, map_dict
  • Loading branch information
Wicknight committed May 13, 2023
2 parents 89be22c + d038632 commit 5bb219d
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
python -m pip install --upgrade pip
pip install pytest
pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install recbole
pip install recbole==1.0.1
conda list
pip install setuptools==59.5.0
pip install protobuf~=3.19.0
Expand Down
12 changes: 12 additions & 0 deletions recbole_cdr/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, model=None, config_file_list=None, config_dict=None):
config_file_list (list of str): the external config file, it allows multiple config files, default is None.
config_dict (dict): the external parameter dictionaries, default is None.
"""
self.compatibility_settings()
self._init_parameters_category()
self.parameters['Dataset'] += ['source_domain', 'target_domain']
self.yaml_loader = self._build_yaml_loader()
Expand Down Expand Up @@ -283,3 +284,14 @@ def update(self, other_config):
for key in other_config:
new_config_obj.final_config_dict[key] = other_config[key]
return new_config_obj

def compatibility_settings(self):
import numpy as np
np.bool = np.bool_
np.int = np.int_
np.float = np.float_
np.complex = np.complex_
np.object = np.object_
np.str = np.str_
np.long = np.int_
np.unicode = np.unicode_
3 changes: 2 additions & 1 deletion recbole_cdr/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from recbole.data.dataset import Dataset
from recbole.utils import FeatureSource, FeatureType, set_color
from recbole_cdr.utils import get_keys_from_chainmap_by_order


class CrossDomainSingleDataset(Dataset):
Expand Down Expand Up @@ -112,7 +113,7 @@ def _remap_fields(self, field_names, map_dict):
map_dict (dict): The dict whose keys are the original ids and values are the new ids.
"""
for field_name in field_names:
self.field2id_token[field_name] = list(map_dict.keys())
self.field2id_token[field_name] = get_keys_from_chainmap_by_order(map_dict)
self.field2token_id[field_name] = map_dict
if field_name in self.inter_feat.columns:
self.inter_feat[field_name] = self.inter_feat[field_name].map(lambda x: map_dict.get(x, x))
Expand Down
4 changes: 2 additions & 2 deletions recbole_cdr/model/cross_domain_recommender/bitgcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ def transfer_layer(self, source_all_embeddings, target_all_embeddings):

source_user_laplace = self.source_user_degree_count
target_user_laplace = self.target_user_degree_count
user_laplace = source_user_laplace + target_user_laplace
user_laplace = source_user_laplace + target_user_laplace + 1e-7
source_user_embeddings_lap = (source_user_laplace * source_user_embeddings + target_user_laplace * target_user_embeddings) / user_laplace
target_user_embeddings_lap = source_user_embeddings_lap
source_item_laplace = self.source_item_degree_count
target_item_laplace = self.target_item_degree_count
item_laplace = source_item_laplace + target_item_laplace
item_laplace = source_item_laplace + target_item_laplace + 1e-7
source_item_embeddings_lap = (source_item_laplace * source_item_embeddings + target_item_laplace * target_item_embeddings) / item_laplace
target_item_embeddings_lap = source_item_embeddings_lap

Expand Down
4 changes: 2 additions & 2 deletions recbole_cdr/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from recbole_cdr.utils.utils import get_model, get_trainer
from recbole_cdr.utils.utils import get_model, get_trainer, get_keys_from_chainmap_by_order
from recbole_cdr.utils.enum_type import *

__all__ = [
'get_model', 'get_trainer', 'Enum', 'ModelType', 'CrossDomainDataLoaderState', 'train_mode2state'
'get_model', 'get_trainer', 'Enum', 'ModelType', 'CrossDomainDataLoaderState', 'train_mode2state', 'get_keys_from_chainmap_by_order'
]
7 changes: 7 additions & 0 deletions recbole_cdr/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,10 @@ def get_trainer(model_type, model_name):
return getattr(importlib.import_module('recbole_cdr.trainer'), 'CrossDomainTrainer')
else:
return getattr(importlib.import_module('recbole.trainer'), 'Trainer')


def get_keys_from_chainmap_by_order(map_dict):
merged_dict = dict()
for dict_item in map_dict.maps:
merged_dict.update(dict_item)
return list(merged_dict.keys())

0 comments on commit 5bb219d

Please sign in to comment.