Skip to content
Permalink
Browse files

Merge pull request #66 from Accenture/feature/65

Feature/65
  • Loading branch information...
NicholasMcCarthy committed Aug 15, 2019
2 parents ce8e496 + 9eb99e8 commit 2418021caf26b1d47f66656adc54785fc1833aac
Showing with 17 additions and 3 deletions.
  1. +8 −2 ampligraph/utils/model_utils.py
  2. +9 −1 tests/ampligraph/utils/test_model_utils.py
@@ -134,7 +134,7 @@ def restore_model(model_name_path=None):
with open(model_name_path, 'rb') as fr:
restored_obj = pickle.load(fr)

logger.debug('Restoring model...')
logger.debug('Restoring model ...')
module = importlib.import_module("ampligraph.latent_features.models")
class_ = getattr(module, restored_obj['class_name'])
model = class_(**restored_obj['hyperparams'])
@@ -143,7 +143,13 @@ def restore_model(model_name_path=None):
model.rel_to_idx = restored_obj['rel_to_idx']
model.restore_model_params(restored_obj)
except (IOError, pickle.UnpicklingError) as e:
logger.debug('No model found: {}.'.format(e))
msg = 'Error unpickling model {} : {}.'.format(model_name_path, e)
logger.debug(msg)
raise Exception(msg)
except FileNotFoundError:
msg = 'No model found: {}.'.format(model_name_path)
logger.debug(msg)
raise FileNotFoundError(msg)

return model

@@ -10,9 +10,11 @@
import numpy as np
import numpy.testing as npt
from ampligraph.utils import save_model, restore_model, create_tensorboard_visualizations, write_metadata_tsv

import pytest
import pickle

def test_save_and_restore_model():

models = ('ComplEx', 'TransE', 'DistMult')

for model_name in models:
@@ -61,6 +63,12 @@ def test_save_and_restore_model():
os.remove(example_name)


def test_restore_model_errors():

with pytest.raises(FileNotFoundError):
model = restore_model(model_name_path='filenotfound.model')


def test_create_tensorboard_visualizations():
# TODO: This
pass

0 comments on commit 2418021

Please sign in to comment.
You can’t perform that action at this time.