Skip to content

Commit

Permalink
Support model saving and loading (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Feb 6, 2020
1 parent 88d3124 commit d203d84
Show file tree
Hide file tree
Showing 73 changed files with 30,764 additions and 27,377 deletions.
6 changes: 5 additions & 1 deletion cornac/eval_methods/cross_validation.py
Expand Up @@ -128,13 +128,17 @@ def _next_fold(self):

def evaluate(self, model, metrics, user_based, show_validation):
result = CVResult(model.name)

for _ in range(self.n_folds):
self._get_train_test()
new_model = model.clone() # clone a completely new model
fold_result, _ = BaseMethod.evaluate(
self, model, metrics, user_based, show_validation=False
self, new_model, metrics, user_based, show_validation=False
)
result.append(fold_result)
self._next_fold()

result.organize()

return result, None # no validation result of CV

26 changes: 24 additions & 2 deletions cornac/experiment/experiment.py
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
# ============================================================================

import os
from datetime import datetime

from .result import ExperimentResult
from .result import CVExperimentResult
from ..metrics.rating import RatingMetric
Expand Down Expand Up @@ -42,6 +45,10 @@ class Experiment:
show_validation: bool, optional, default: True
Whether to show the results on validation set (if exists).
save_dir: str, optional, default: None
Path to a directory for storing trained models and logs. If None,
models will NOT be stored and logs will be saved in the current working directory.
Attributes
----------
Expand All @@ -63,13 +70,15 @@ def __init__(
user_based=True,
show_validation=True,
verbose=False,
save_dir=None,
):
self.eval_method = eval_method
self.models = self._validate_models(models)
self.metrics = self._validate_metrics(metrics)
self.user_based = user_based
self.show_validation = show_validation
self.verbose = verbose
self.save_dir = save_dir
self.result = None
self.val_result = None

Expand Down Expand Up @@ -110,6 +119,7 @@ def _create_result(self):
self.val_result = ExperimentResult()

def run(self):
"""Run the Cornac experiment"""
self._create_result()

for model in self.models:
Expand All @@ -119,11 +129,23 @@ def run(self):
user_based=self.user_based,
show_validation=self.show_validation,
)

self.result.append(test_result)
if self.val_result is not None:
self.val_result.append(val_result)

if not isinstance(self.result, CVExperimentResult):
model.save(self.save_dir)

output = ""
if self.val_result is not None:
print("\nVALIDATION:\n...\n{}".format(self.val_result))
output += "\nVALIDATION:\n...\n{}".format(self.val_result)
output += "\nTEST:\n...\n{}".format(self.result)

print(output)

print("\nTEST:\n...\n{}".format(self.result))
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
save_dir = "." if self.save_dir is None else self.save_dir
output_file = os.path.join(save_dir, "CornacExp-{}.log".format(timestamp))
with open(output_file, "w") as f:
f.write(output)
4,791 changes: 2,403 additions & 2,388 deletions cornac/models/baseline_only/recom_bo.cpp

Large diffs are not rendered by default.

29 changes: 19 additions & 10 deletions cornac/models/baseline_only/recom_bo.pyx
Expand Up @@ -15,6 +15,8 @@

# cython: language_level=3

import multiprocessing

cimport cython
from cython.parallel import prange
from cython cimport floating, integral
Expand All @@ -23,8 +25,10 @@ from libc.math cimport abs

import numpy as np
cimport numpy as np
from tqdm import trange

from ..recommender import Recommender
from ...utils.init_utils import zeros


class BaselineOnly(Recommender):
Expand Down Expand Up @@ -74,17 +78,28 @@ class BaselineOnly(Recommender):
self.learning_rate = learning_rate
self.lambda_reg = lambda_reg
self.early_stop = early_stop
self.init_params = {} if init_params is None else init_params
self.seed = seed

import multiprocessing
if seed is not None:
self.num_threads = 1
elif num_threads > 0 and num_threads < multiprocessing.cpu_count():
self.num_threads = num_threads
else:
self.num_threads = multiprocessing.cpu_count()

# Init params if provided
self.init_params = {} if init_params is None else init_params
self.u_biases = self.init_params.get('Bu', None)
self.i_biases = self.init_params.get('Bi', None)
self.global_mean = 0.0

def _init(self):
n_users, n_items = self.train_set.num_users, self.train_set.num_items

self.global_mean = self.train_set.global_mean
self.u_biases = zeros(n_users) if self.u_biases is None else self.u_biases
self.i_biases = zeros(n_items) if self.i_biases is None else self.i_biases

def fit(self, train_set, val_set=None):
"""Fit the model to observations.
Expand All @@ -102,14 +117,9 @@ class BaselineOnly(Recommender):
"""
Recommender.fit(self, train_set, val_set)

n_users, n_items = train_set.num_users, train_set.num_items
self.global_mean = train_set.global_mean

from ...utils.init_utils import zeros
self.u_biases = self.init_params.get('Bu', zeros(n_users))
self.i_biases = self.init_params.get('Bi', zeros(n_items))

if self.trainable:
self._init()

(rid, cid, val) = train_set.uir_tuple
self._fit_sgd(rid, cid, val.astype(np.float32), self.u_biases, self.i_biases)

Expand Down Expand Up @@ -139,7 +149,6 @@ class BaselineOnly(Recommender):
floating r, r_pred, error, delta_loss
integral u, i, j

from tqdm import trange
progress = trange(max_iter, disable=not self.verbose)
for epoch in progress:
last_loss = loss
Expand Down

0 comments on commit d203d84

Please sign in to comment.