Skip to content

Commit

Permalink
Merge pull request #96 from TomNong/GPT-2-refined
Browse files Browse the repository at this point in the history
Gpt 2 refined
  • Loading branch information
ZhitingHu committed Mar 2, 2019
2 parents fe8373e + c62e6a1 commit fc15fa2
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,6 @@ simple-examples.tgz
/examples/bert/bert_pretrained_models/
!/examples/bert/bert_pretrained_models/download_model.sh
/examples/bert/output
/examples/gpt-2/gpt2_pretrained_models/
!/examples/gpt-2/gpt2_pretrained_models/download_model.sh

4 changes: 2 additions & 2 deletions examples/gpt-2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ In sum, this example showcases:

Download the GPT-2 model checkpoint with the following command:
```
sh gpt2_pretrained_models/download_model.sh 117M
sh gpt2_pretrained_models/download_model.sh model_117M
```
By default, it will download a pretrained model named `117M` to `gpt2_pretrained_models/`.
By default, it will download a pretrained model named `model_117M` to `gpt2_pretrained_models/`.

### Usage
| WARNING: Samples are unfiltered and may contain offensive content. |
Expand Down
3 changes: 3 additions & 0 deletions examples/gpt-2/configs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Configuration files of GPT-2 models in Texar style.

For example, `config_model.py` is the Texar configuration file corresponding to the `model_117M` model downloaded from [GPT-2 official release](https://github.com/openai/gpt-2).
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Texar config file of the GPT-2 117M model.
"""Texar config file of the GPT-2 model_117M model.
"""

vocab_size = 50257
Expand Down
3 changes: 0 additions & 3 deletions examples/gpt-2/gpt2_config_lib/README.md

This file was deleted.

37 changes: 23 additions & 14 deletions examples/gpt-2/gpt2_generate_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from texar.modules.decoders.transformer_decoders import TransformerDecoder

from utils import model_utils, processor
import configs

# pylint: disable=invalid-name, too-many-locals, too-many-statements, no-member
# pylint: disable=too-many-branches
Expand All @@ -34,10 +35,9 @@

FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoint", "gpt2_pretrained_models/117M/model.ckpt",
flags.DEFINE_string("checkpoint", "gpt2_pretrained_models/"
"model_117M/model.ckpt",
"Model checkpoint to load model weights from.")
flags.DEFINE_string("config_gpt2", "117M",
"The of the GPT-2 config file to use.")
flags.DEFINE_integer("seed", None, "Random seed.")
flags.DEFINE_integer("nsamples", 1, "The number of samples per input.")
flags.DEFINE_integer("batch_size", 1, "The batch size of input.")
Expand All @@ -50,11 +50,20 @@
"The number of top most likely candidates from a vocab "
"distribution.")
flags.DEFINE_boolean("is_interactive", False, "Interactive mode or not.")
flags.DEFINE_string("config_format", "json",
"The configuration file format. Set to 'json' if the GPT-2 "
"config file is in the same format of the official GPT-2 "
flags.DEFINE_string("config_type", "texar",
"The configuration file type. Set to 'json' if the GPT-2 "
"config file is in the same type of the official GPT-2 "
"config file. Set to 'texar' if GPT-2 config file is in "
"Texar format.")
"Texar type.")
flags.DEFINE_string("config_model", "configs.config_model",
"The model configuration file to configure the model. "
"The config file type is define by the 'config_type',"
"it be of texar type or json type."
"For '--config_type=json', set the json config file path"
"like: '--config_model gpt2_pretrained_models/model_117M/"
"hparams.json';"
"For '--config_type=texar', set the texar config file "
"like: '--config_model configs.config_model'.")

def main(_):
"""
Expand All @@ -69,24 +78,24 @@ def main(_):


ckpt_path = FLAGS.checkpoint
gpt2_config_dir = "gpt2_pretrained_models/%s" % FLAGS.config_gpt2
# Load GPT-2 model configuration
if FLAGS.config_format == "json":
if FLAGS.config_type == "json":
gpt2_config = model_utils.transform_gpt2_to_texar_config(
os.path.join(gpt2_config_dir, 'hparams.json'))
elif FLAGS.config_format == 'texar':
FLAGS.config_model)
elif FLAGS.config_type == 'texar':
gpt2_config = importlib.import_module(
'gpt2_config_lib.config_model_{}'.format(FLAGS.config_gpt2))
FLAGS.config_model)
else:
raise ValueError('Unknown config_format.')
raise ValueError('Unknown config_type.')

assert max_decoding_length <= gpt2_config.decoder["position_size"], (
"max_decoding_length should be smaller than position size")
assert nsamples % batch_size == 0, (
"nsamples must be dividable by batch_size")

# Create a data pre-processor for, e.g., BPE encoding
proc = processor.get_encoder(gpt2_config_dir)
proc = processor.get_encoder(
"gpt2_pretrained_models/model_117M")

context = tf.placeholder(tf.int32, [batch_size, None])
context_length = tf.placeholder(tf.int32, [batch_size])
Expand Down
9 changes: 5 additions & 4 deletions examples/gpt-2/gpt2_pretrained_models/download_model.sh
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
#!/bin/sh

if [ "$#" -ne 1 ]; then
echo "You must enter the model name as a parameter, e.g.: sh gpt2_pretrained_models/download_model.sh 117M"
echo "You must enter the model name as a parameter, e.g.: sh gpt2_pretrained_models/download_model.sh model_117M"
exit 1
fi

model=$1

model_name=${model#*_}
mkdir -p gpt2_pretrained_models/$model
for filename in checkpoint encoder.json hparams.json model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta vocab.bpe; do
fetch=$model/$filename
fetch=$model_name/$filename
sub_path=$model/$filename
echo "Fetching $fetch"
curl --output gpt2_pretrained_models/$fetch https://storage.googleapis.com/gpt-2/models/$fetch
curl --output gpt2_pretrained_models/$sub_path https://storage.googleapis.com/gpt-2/models/$fetch
done
2 changes: 2 additions & 0 deletions examples/gpt-2/utils/processor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
#
"""
Byte pair encoding utilities
Expand Down

0 comments on commit fc15fa2

Please sign in to comment.