Skip to content

Commit

Permalink
I'll merge these commits later. Promise.
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladislav Zavadskyy committed May 25, 2018
1 parent 2969c6c commit 4fd1fbb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
7 changes: 2 additions & 5 deletions nasframe/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,15 @@ def find_best(filename, config, preprocess, draw, save_to):
@click.option('--force-perprocess', help=
'Will force preprocessing, even if preprocessed data exists.',
is_flag=True, default=False)
@click.option('--curriculum', help=
'If set, will perform curriculum architect training.',
is_flag=True, default=False)
def toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess, curriculum):
def toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess):
"""
Preforms neural architecture search on Jigsaw Toxic Comment dataset
"""
assert 0 < val_fraction < 1, 'Validation data fraction has to be in range (0,1).'
assert num_gpus >= 1, 'Number of GPUs has to be >= 1.'

from .toxiccomment import train_toxic
train_toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess, curriculum)
train_toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess)


@nas.command()
Expand Down
3 changes: 2 additions & 1 deletion nasframe/scripts/toxiccomment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def prepare_data(data_dir, val_fraction, embedding_path, embedding_dim, force_pe
return embedding


def train_toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess, curriculum):
def train_toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perprocess):
"""
Trains architect (performs NAS) of Jigsaw Toxic Comment dataset.
See cli help, for parameter description.
Expand All @@ -102,6 +102,7 @@ def train_toxic(num_gpus, val_fraction, resume, config_path, gpu_idx, force_perp
)

input_shape = config['child_training'].pop('input_shape')
curriculum = config['architect_training'].get('curriculum', False)
if curriculum:
storage = train_curriculum(
config, worker_fn,
Expand Down

0 comments on commit 4fd1fbb

Please sign in to comment.