Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions cornac/data/trainset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class MatrixTrainSet(TrainSet):
The dictionary containing mapping from original ids to mapped ids of items.

seed: int, optional, default: None
Random seed for reproduce data sampling.
Random seed for reproducing data sampling.

"""

Expand Down Expand Up @@ -179,7 +179,7 @@ def item_ppl_rank(self):

@classmethod
def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
global_ui_set=None, verbose=False):
global_ui_set=None, seed=None, verbose=False):
"""Constructing TrainSet from triplet data.

Parameters
Expand All @@ -196,6 +196,9 @@ def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
global_ui_set: :obj:`set`, optional, default: None
The global set of tuples (user, item). This helps avoiding duplicate observations.

seed: int, optional, default: None
Random seed for reproducing data sampling.

verbose: bool, default: False
The verbosity flag.

Expand Down Expand Up @@ -258,7 +261,7 @@ def from_uir(cls, data, global_uid_map=None, global_iid_map=None,
print('Min rating = {:.1f}'.format(min_rating))
print('Global mean = {:.1f}'.format(global_mean))

return cls(uir_tuple, max_rating, min_rating, global_mean, uid_map, iid_map)
return cls(uir_tuple, max_rating, min_rating, global_mean, uid_map, iid_map, seed=seed)

def num_batches(self, batch_size):
return estimate_batches(len(self.uir_tuple[0]), batch_size)
Expand Down
2 changes: 1 addition & 1 deletion cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _build_uir(self, train_data, test_data, val_data=None):
if self.verbose:
print('Building training set')
self.train_set = MultimodalTrainSet.from_uir(
train_data, self.global_uid_map, self.global_iid_map, global_ui_set, self.verbose)
train_data, self.global_uid_map, self.global_iid_map, global_ui_set, self.seed, self.verbose)

if self.verbose:
print('Building test set')
Expand Down