Skip to content

Commit

Permalink
Re draft auto module.
Browse files Browse the repository at this point in the history
  • Loading branch information
uduse committed Dec 10, 2018
1 parent 617bad8 commit af98856
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 148 deletions.
2 changes: 1 addition & 1 deletion matchzoo/auto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .tune import tune
from .director import Director
from .prepare import prepare
119 changes: 0 additions & 119 deletions matchzoo/auto/director.py

This file was deleted.

60 changes: 60 additions & 0 deletions matchzoo/auto/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import copy

import numpy as np

from matchzoo import tasks
from matchzoo import models


def prepare(
    model,
    train_pack,
    preprocessor=None,
    verbose=1
):
    """
    Prepare an untrained model and a raw data pack for training.

    Builds a ready-to-fit copy of ``model`` whose parameters are completed
    from the data, together with the processed data pack and the fitted
    preprocessor. The original ``model`` and ``preprocessor`` are not
    modified (both are deep-copied).

    :param model: A matchzoo model instance whose parameters serve as a
        template for the prepared model.
    :param train_pack: Raw :class:`matchzoo.DataPack` used to fit the
        preprocessor and, when no task is set, to guess one.
    :param preprocessor: Optional preprocessor to copy and fit. When
        omitted, ``model.get_default_preprocessor()`` is used.
    :param verbose: Verbosity forwarded to preprocessing and to
        ``guess_and_fill_missing_params``.
    :return: A tuple of ``(new_model, train_pack_processed,
        new_preprocessor)`` where ``new_model`` is built and compiled.
    """
    params = copy.deepcopy(model.params)

    if preprocessor:
        fitted_preprocessor = copy.deepcopy(preprocessor)
    else:
        fitted_preprocessor = model.get_default_preprocessor()

    processed_pack = fitted_preprocessor.fit_transform(train_pack, verbose)

    # Fill in a task inferred from the labels when the caller left it unset.
    if not params['task']:
        params['task'] = _guess_task(train_pack)

    context = {}
    if 'input_shapes' in fitted_preprocessor.context:
        context['input_shapes'] = fitted_preprocessor.context['input_shapes']

    # DSSM-style models need their input shapes taken from preprocessing.
    if isinstance(model, models.DSSMModel):
        params['input_shapes'] = context['input_shapes']

    # Size the embedding input from the fitted vocabulary (+1 for padding).
    if 'with_embedding' in params:
        term_index = (
            fitted_preprocessor.context['vocab_unit'].state['term_index']
        )
        params['embedding_input_dim'] = len(term_index) + 1

    prepared_model = type(model)(params=params)
    prepared_model.guess_and_fill_missing_params(verbose=verbose)
    prepared_model.build()
    prepared_model.compile()

    return prepared_model, processed_pack, fitted_preprocessor


def _guess_task(train_pack):
    """
    Infer a task from the labels of ``train_pack``.

    Numeric labels are treated as a ranking problem. List-valued labels
    (e.g. one-hot class vectors), which pandas stores in an object-dtype
    column, are treated as classification, with the number of classes
    taken from the longest label. Returns ``None`` when neither case
    applies, so the caller's parameter table keeps its default.

    :param train_pack: :class:`matchzoo.DataPack` whose ``relation``
        frame has a ``'label'`` column.
    :return: ``tasks.Ranking()``, ``tasks.Classification(num_classes)``,
        or ``None``.
    """
    labels = train_pack.relation['label']
    if np.issubdtype(labels.dtype, np.number):
        return tasks.Ranking()
    # BUG FIX: `np.issubdtype(dtype, list)` is invalid -- `list` is not a
    # numpy scalar type, so the check raises (or misbehaves) instead of
    # detecting list labels. List-valued labels live in an object-dtype
    # column, so test for that directly.
    elif labels.dtype == np.dtype(object):
        num_classes = int(labels.apply(len).max())
        return tasks.Classification(num_classes)
    return None
41 changes: 13 additions & 28 deletions matchzoo/auto/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
import hyperopt

from matchzoo import engine
from matchzoo import models


def tune(
model: engine.BaseModel,
train_pack,
test_pack,
task,
context=None,
max_evals: int = 32,
verbose=1
) -> list:
Expand All @@ -26,52 +23,40 @@ def tune(
:param model: Model to tune.
:param train_pack: :class:`matchzoo.DataPack` to train the model.
:param test_pack: :class:`matchzoo.DataPack` to test the model.
:param task: :class:`matchzoo.engine.BaseTask` to execute.
:param context: Extra information for tunning. Different for different
models.
:param max_evals: Number of evaluations of a single tuning process.
:param verbose: Verbosity.
:return: A list of trials of the tuning process.
"""

def _test_wrapper(space):
for key, value in space.items():
model.params[key] = value

if isinstance(model, models.DSSMModel):
input_shapes = context['input_shapes']
model.params['input_shapes'] = input_shapes

if 'with_embedding' in model.params:
model.params['embedding_input_dim'] = context['vocab_size']

model.params['task'] = task
model.guess_and_fill_missing_params(verbose=verbose)
model.build()
model.compile()

model.fit(*train_pack.unpack(), verbose=verbose)
metrics = model.evaluate(*test_pack.unpack(), verbose=verbose)

results = _eval_model()
return {
'loss': metrics['loss'],
'loss': results['loss'],
'space': space,
'status': hyperopt.STATUS_OK,
'model_params': model.params
}

def _eval_model():
model.build()
model.compile()
model.fit(*train_pack.unpack(), verbose=verbose)
return model.evaluate(*test_pack.unpack(), verbose=verbose)

if not model.params.hyper_space:
raise ValueError("Model hyper parameter space empty.")

trials = hyperopt.Trials()
hyper_space = model.params.hyper_space
if not hyper_space:
raise ValueError("Cannot auto-tune on an empty hyper space.")
hyperopt.fmin(
fn=_test_wrapper,
space=hyper_space,
space=model.params.hyper_space,
algo=hyperopt.tpe.suggest,
max_evals=max_evals,
trials=trials
)

return [_clean_up_trial(trial) for trial in trials]


Expand Down

0 comments on commit af98856

Please sign in to comment.