Skip to content

Commit

Permalink
Resolve #67 Polish GPT-2 example(#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomNong authored and huzecong committed Jun 30, 2019
1 parent 77f3b39 commit 5a90686
Showing 1 changed file with 57 additions and 68 deletions.
125 changes: 57 additions & 68 deletions examples/gpt-2/gpt2_generate_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,70 +29,56 @@
from utils import model_utils, processor

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint',
type=str,
default=None,
help="Model checkpoint to load model weights from. Use "
"`--pretrain_checkpoint` instead if loading OpenAI "
"pretrained checkpoint.")
parser.add_argument('--pretrain_checkpoint',
type=str,
default="gpt2_pretrained_models/model_117M/model.ckpt",
help="OpenAI pretrained model checkpoint. Ignored if "
"'--checkpoint' is specified.")
parser.add_argument('--pretrain_model_dir',
type=str,
default="gpt2_pretrained_models/model_117M",
help="The directory of pretrained model, for loading "
"vocabuary, etc.")
parser.add_argument('--seed',
type=int,
default=None,
help="Random seed.")
parser.add_argument('--nsamples',
type=int,
default=1,
help="The number of samples per input.")
parser.add_argument('--batch_size',
type=int,
default=1,
help="The batch size of input.")
parser.add_argument('--max_decoding_length',
type=int,
default=128,
help="The maximun length of generated text.")
parser.add_argument('--temperature',
type=float,
default=0.7,
help="Softmax temperature for top-k sample decoding. "
"Must be strictly greater than 0. Defaults to 0.7.")
parser.add_argument('--top_k',
type=int,
default=40,
help="The number of top most likely candidates from a "
"vocab distribution.")
parser.add_argument('--is_interactive',
action='store_true',
help="Interactive mode or not.")
parser.add_argument('--config_type',
type=str,
default="texar",
help="The configuration file type. Set to 'json' if the "
"GPT-2 config file is in the same type of the "
"official GPT-2 config file. Set to 'texar' "
"if GPT-2 config file is in Texar type.")
parser.add_argument('--config_model',
type=str,
default="configs.config_model_117M",
help="The model configuration file to configure the "
"model. The config file type is define by the "
"'config_type',it be of texar type or json type."
"For '--config_type=json', set the json "
"config file path like: '--config_model "
"gpt2_pretrained_models/model_117M/hparams.json';"
"For '--config_type=texar', set the texar "
"config file like: "
"'--config_model configs.config_model_117M'.")
parser.add_argument(
'--checkpoint', type=str, default=None,
help="Model checkpoint to load model weights from. Use "
"`--pretrain_checkpoint` instead if loading OpenAI pretrained "
"checkpoint.")
parser.add_argument(
'--pretrain_checkpoint', type=str,
default="gpt2_pretrained_models/model_117M/model.ckpt",
help="OpenAI pretrained model checkpoint. Ignored if '--checkpoint' "
"is specified.")
parser.add_argument(
'--pretrain_model_dir', type=str,
default="gpt2_pretrained_models/model_117M",
help="The directory of pretrained model, for loading vocabuary, etc.")
parser.add_argument(
'--seed', type=int, default=None,
help="Random seed.")
parser.add_argument(
'--nsamples', type=int, default=1,
help="The number of samples per input.")
parser.add_argument(
'--batch_size', type=int, default=1,
help="The batch size of input.")
parser.add_argument(
'--max_decoding_length', type=int, default=128,
help="The maximun length of generated text.")
parser.add_argument(
'--temperature', type=float, default=0.7,
help="Softmax temperature for top-k sample decoding. Must be strictly "
"greater than 0. Defaults to 0.7.")
parser.add_argument(
'--top_k', type=int, default=40,
help="The number of top most likely candidates from a vocab distribution.")
parser.add_argument(
'--is_interactive', action='store_true',
help="Interactive mode or not.")
parser.add_argument(
'--config_type', type=str, default="texar",
help="The configuration file type. Set to 'json' if the GPT-2 config file "
"is in the same type of the official GPT-2 config file. Set to "
"'texar' if GPT-2 config file is in Texar type.")
parser.add_argument(
'--config_model', type=str, default="configs.config_model_117M",
help="The model configuration file to configure the model. The config "
"file type is define by the 'config_type',it be of texar type "
"or json type. For '--config_type=json', set the json config "
"file path like: "
"'--config_model gpt2_pretrained_models/model_117M/hparams.json';"
"For '--config_type=texar', set the texar config file like: "
"'--config_model configs.config_model_117M'.")

args = parser.parse_args()

Expand All @@ -119,8 +105,8 @@ def __init__(self, gpt2_config, top_k, temperature):
self._embedding_fn = lambda x, y: (
self.word_embedder(x) + self.pos_embedder(y))

def forward(self, start_tokens, end_token, context, context_sequence_length,
max_decoding_length):
def forward(self, start_tokens, end_token, context,
context_sequence_length, max_decoding_length):
helper = tx.modules.TopKSampleEmbeddingHelper(
embedding=self._embedding_fn,
start_tokens=start_tokens,
Expand Down Expand Up @@ -195,13 +181,16 @@ def run_model():
print('Input should not be empty!')
raw_text = input("Model input >>> ")
except EOFError:
print("EOF entered, quitting.")
exit(0)

context_tokens = proc.encode(raw_text)
context = torch.tensor(
[context_tokens for _ in range(batch_size)], device=device)
[context_tokens for _ in range(batch_size)],
device=device)
context_length = torch.tensor(
[len(context_tokens) for _ in range(batch_size)], device=device)
[len(context_tokens) for _ in range(batch_size)],
device=device)

start_tokens = context[:, 0]

Expand Down

0 comments on commit 5a90686

Please sign in to comment.