Update gpt2 README
ZhitingHu committed Apr 13, 2019
1 parent ecc371b commit 0061bdb
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions examples/gpt-2/README.md
@@ -100,7 +100,12 @@ Texar's `TransformerDecoder` (and other RNN-based decoders) easily supports comm

**For example**, after creating the language model
```python
def _embedding_fn(ids, times):
    return word_embedder(ids) + pos_embedder(times)

decoder = TransformerDecoder(
    output_layer=tf.transpose(word_embedder.embedding),
    hparams=gpt2_hparams)
```
We can do

@@ -111,17 +116,18 @@
```python
output, output_length = decoder(
    context=ctx,
    context_sequence_length=ctx_len,
    decoding_strategy='infer_greedy',
    end_token=end_token,
    embedding=_embedding_fn)

sample_id = output.sample_id
logits = output.logits
```
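
For readers following the excerpt without the full README: the snippets assume `word_embedder`, `pos_embedder`, `gpt2_hparams`, `ctx`, `ctx_len`, and `end_token` already exist. A minimal sketch of how the two embedders might be set up with Texar's modules is below; the sizes shown are illustrative assumptions, not the example's actual configuration.

```python
import texar as tx

# Illustrative sizes only; the real values come from the pretrained GPT-2 config.
word_embedder = tx.modules.WordEmbedder(
    vocab_size=50257, hparams={"dim": 768})
pos_embedder = tx.modules.PositionEmbedder(
    position_size=1024, hparams={"dim": 768})
# `gpt2_hparams`, `ctx`, `ctx_len`, and `end_token` are likewise assumed to be
# prepared elsewhere (checkpoint config, context token batch, and EOS token id).
```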

**Ex. Use 2): Top-k sample decoding**

```python
topk_helper = tx.modules.TopKSampleEmbeddingHelper(
    embedding=_embedding_fn,
    start_tokens=ctx[:,0],
    end_token=end_token,
    top_k=20,
    ...
```
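
The hunk is cut off before the helper is used. In Texar's helper-based decoding, the helper object is then passed to the decoder call, roughly as sketched below; the `max_decoding_length` value is an assumed placeholder, and the actual arguments used in the README are truncated in the diff above.

```python
# Sketch only: pass the sampling helper to the decoder.
output, output_length = decoder(
    context=ctx,
    context_sequence_length=ctx_len,
    max_decoding_length=100,   # assumed value for illustration
    helper=topk_helper)
```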
@@ -136,10 +142,12 @@ output, output_length = decoder(
**Ex. Use 3): Fine-tuning for conditional generation**

```python
tgt_embed = word_embedder(truth_target[:, :-1]) + pos_embedder(sequence_length=tgt_len-1)

output = decoder(
    memory=source_hidden_states,
    memory_sequence_length=src_len,
    inputs=tgt_embed,
    decoding_strategy='train_greedy')   # teacher-forcing decoding

loss = tx.losses.sequence_sparse_softmax_cross_entropy(
    ...
```
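
The loss call above is truncated by the diff. For reference, a typical invocation of this Texar loss takes labels, logits, and sequence lengths; the specific tensors below are assumptions for illustration, not the example's exact code.

```python
# Hypothetical completion for illustration; the README's actual arguments are
# cut off above. Labels are the gold next tokens, logits come from the decoder.
loss = tx.losses.sequence_sparse_softmax_cross_entropy(
    labels=truth_target[:, 1:],
    logits=output.logits,
    sequence_length=tgt_len - 1)
```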
