Skip to content

Commit

Permalink
remove opt from make_base_model signature, pass gpuid explicitly, sim…
Browse files Browse the repository at this point in the history
…plify model and generator creation
  • Loading branch information
bpopeters committed Sep 12, 2017
1 parent 7d7cd9a commit 8df0e5e
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions onmt/ModelConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def make_decoder(opt, embeddings):
embeddings)


def make_base_model(opt, model_opt, fields, checkpoint=None):
def make_base_model(model_opt, fields, gpuid, checkpoint=None):
"""
Args:
opt: the option in current environment.
model_opt: the option loaded from checkpoint.
fields: `Field` objects for the model.
gpuid: a list of integers, one for each gpu
checkpoint: the snapshot model.
Returns:
the NMTModel.
Expand Down Expand Up @@ -163,26 +163,19 @@ def make_base_model(opt, model_opt, fields, checkpoint=None):
generator = CopyGenerator(model_opt, fields["src"].vocab,
fields["tgt"].vocab)

# Load the modle states from checkpoint.
# Load the model states from checkpoint.
if checkpoint is not None:
print('Loading model')
model.load_state_dict(checkpoint['model'])
generator.load_state_dict(checkpoint['generator'])

# add the generator to the module (does this register the parameter?)
model.generator = generator

# Make the whole model leverage GPU if indicated to do so.
if hasattr(opt, 'gpuid'):
cuda = len(opt.gpuid) >= 1
elif hasattr(opt, 'gpu'):
cuda = opt.gpu > -1
else:
cuda = False

if cuda:
if gpuid:
model.cuda()
generator.cuda()
else:
model.cpu()
generator.cpu()
model.generator = generator

return model

0 comments on commit 8df0e5e

Please sign in to comment.