Merge pull request #371 from NVIDIA/jenkins_test
glue fix for pretrained checkpoint
yzhang123 committed Feb 14, 2020
2 parents 3e04e09 + b43d64d commit 6b3be90
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions examples/nlp/glue_benchmark/glue_benchmark_with_bert.py
@@ -61,9 +61,10 @@
"""

import argparse
import json
import os

from transformers import BertConfig

import nemo.collections.nlp as nemo_nlp
import nemo.core as nemo_core
from nemo import logging
@@ -187,6 +188,7 @@
    add_time_to_log_dir=True,
)


if args.bert_checkpoint is None:
    """ Use this if you're using a standard BERT model.
    To see the list of pretrained models, call:
@@ -205,14 +207,31 @@
        tokenizer = NemoBertTokenizer(args.pretrained_bert_model)
    else:
        raise ValueError(f"received unexpected tokenizer '{args.tokenizer}'")

    if args.bert_config is not None:
        with open(args.bert_config) as json_file:
            config = json.load(json_file)
        model = nemo_nlp.nm.trainables.huggingface.BERT(**config)
        config = BertConfig.from_json_file(args.bert_config).to_dict()
        args.vocab_size = config['vocab_size']
        args.hidden_size = config['hidden_size']
        args.num_hidden_layers = config['num_hidden_layers']
        args.num_attention_heads = config['num_attention_heads']
        args.intermediate_size = config['intermediate_size']
        args.hidden_act = config['hidden_act']
        args.max_seq_length = config['max_position_embeddings']

        model = nemo_nlp.nm.trainables.huggingface.BERT(
            vocab_size=args.vocab_size,
            num_hidden_layers=args.num_hidden_layers,
            hidden_size=args.hidden_size,
            num_attention_heads=args.num_attention_heads,
            intermediate_size=args.intermediate_size,
            max_position_embeddings=args.max_seq_length,
            hidden_act=args.hidden_act,
        )
        logging.info(f"using {args.bert_config}")
    else:
        model = nemo_nlp.nm.trainables.huggingface.BERT(pretrained_model_name=args.pretrained_bert_model)

    model.restore_from(args.bert_checkpoint)
    logging.info(f"model restored from {args.bert_checkpoint}")

hidden_size = model.hidden_size

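For context on the change above: the new branch parses the JSON file through transformers.BertConfig rather than a raw json.load, so the contents are interpreted against the library's BERT config schema and any field absent from the file should fall back to the library's default. Below is a minimal standalone sketch of that parsing step; the config path is a hypothetical placeholder, not a file from this commit.

from transformers import BertConfig

# Parse a HuggingFace-style BERT config JSON into a plain dict.
# "bert_config.json" is an illustrative path only.
config = BertConfig.from_json_file("bert_config.json").to_dict()

# The seven keys the GLUE script copies onto its args:
for key in (
    "vocab_size",
    "hidden_size",
    "num_hidden_layers",
    "num_attention_heads",
    "intermediate_size",
    "hidden_act",
    "max_position_embeddings",
):
    print(key, "=", config[key])

One likely reason for passing the fields explicitly instead of unpacking **config as the old code did: to_dict() also returns keys (dropout probabilities, type_vocab_size, and so on) that the NeMo BERT wrapper's constructor may not accept.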
