Add main method for xlnet text generation (#100)
* Add main method for xlnet text generation

* Remove IPython dependency

* Add TopKSampleEmbeddingHelper and TopPSampleEmbeddingHelper support.

* Contract the imports

* Use default p=0.9

* Polish README

* Update xlnet README

* Update xlnet README

* Add torch.no_grad(). Fix OOM and tensor related issue

* Fix minor issues in decoder
AvinashBukkittu committed Jul 12, 2019
1 parent c04990d commit 432a68f
Showing 5 changed files with 293 additions and 42 deletions.
130 changes: 109 additions & 21 deletions examples/xlnet/README.md
@@ -1,28 +1,29 @@
# XLNet: Pre-trained models and downstream applications
# XLNet for Classification and Generation

This is a Texar-PyTorch implementation of the [XLNet model](https://github.com/zihangdai/xlnet), which supports loading
pre-trained model weights downloaded from the official release and building/fine-tuning downstream applications.

To summarize, this example showcases:

- Use of pre-trained XLNet models in Texar-PyTorch.
- Building and fine-tuning on downstream tasks.
- Use of Texar-PyTorch RecordData module for data loading and processing.
- [Fine-tuning XLNet for classification](#classification)
- [XLNet for text generation](#generation)
- [XLNet for other custom tasks](#extend-to-custom-tasks)

**Note:**
- This example has reproduced the reported results on STS-B and IMDB on GPUs. As per
- For classification, this example has reproduced the reported results on STS-B and IMDB on GPUs. As per
[the official repository](https://github.com/zihangdai/xlnet#memory-issue-during-finetuning), computational resources
(e.g., GPU memory) can affect the results.
- This example supports classification and GLUE datasets. Other datasets can be supported by adding respective data
modules. See [this section](https://github.com/huzecong/texar-xlnet#extend-to-custom-tasks).
modules. See [this section](#extend-to-custom-tasks).

**Future Work:**
- Distributed / Multi-GPU training.
- Fine-tuning on SQuAD & RACE datasets.
- *Please open an issue for features you would like to see*

## Quickstart
## Prerequisite

### Install dependencies
#### Install dependencies

Apart from requiring Texar-PyTorch, you should also satisfy dependencies in `requirements.txt` by running:
```bash
@@ -32,15 +33,17 @@ pip install -r requirements.txt
Specifically, TensorFlow is required to load the pre-trained models from the official release, and sentencepiece is
required for tokenization.

### Download pre-trained model
#### Download pre-trained model

```bash
sh scripts/download_model.sh
```

By default, the pre-trained model (XLNet-Large Cased) will be downloaded to `pretrained/xlnet_cased_L-24_H-1024_A-16`.

### Download dataset
## Classification

#### Download dataset

We will use the STS-B sentence pair relevance dataset as an example. Routines for other datasets are similar.

@@ -52,7 +55,7 @@ Data will be downloaded to `data/STS-B` along with other GLUE datasets.

Note that this is a regression task; the evaluation metric is Pearson's r correlation.
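
As a quick reference, Pearson's r can be computed directly from a set of predictions and gold scores with `scipy.stats.pearsonr`. The snippet below is a minimal sketch using dummy data; it is not part of this example's code:

```python
import numpy as np
from scipy.stats import pearsonr

# Dummy predictions and gold similarity scores for five sentence pairs.
predictions = np.array([4.2, 1.1, 3.8, 0.5, 2.9])
gold_scores = np.array([4.0, 0.8, 3.5, 1.0, 3.2])

r, p_value = pearsonr(predictions, gold_scores)
print(f"Pearson's r: {r:.4f} (p = {p_value:.4g})")
```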

### Fine-tune the model
#### Fine-tune the model

To fine-tune the model on the dataset, run the following command:
```bash
@@ -156,7 +159,7 @@ Evaluating on test
Pearsonr: nan, loss: 9.1475
```

### Evaluate saved models
#### Evaluate saved models

To evaluate a saved model, run the following command:
```bash
@@ -166,17 +169,102 @@ python xlnet_classification_main.py \
--mode eval
```

## Text generation
## Generation

Since XLNet is in essence a language model, it could be used to autoregressively generate text. We have also provided
an example to showcase text generation abilities of XLNet.
examples to showcase text generation abilities of XLNet:

- [Interactive mode (to generate samples with context)](#interactive-mode-to-generate-samples-with-context)
- [Non-interactive mode (to generate samples from scratch)](#non-interactive-mode-to-generate-samples-from-scratch)
- [IPython mode (to play with different decoding strategies)](#ipython-mode-to-play-with-different-decoding-strategies)

| WARNING: Samples are unfiltered and may contain offensive content. |
| --- |

#### Interactive mode (to generate samples with context)

This mode starts an interactive interface, which allows users to type in a context sentence. The model then generates a continuation of the context. The example supports both Top-K and Top-P sample decoding.

```bash
python xlnet_generation_main.py --is_interactive \
--max_decoding_length=200 \
--temperature=0.7 \
--top_k=40
```

Here:

- `is_interactive`: Specifies interactive mode.
- `max_decoding_length`: The maximum number of tokens in the sample. **Note that this includes tokens in the context**.
- `nsamples`: Number of samples to generate for each input.

For *top-k decoding*:

- `temperature`: Softmax temperature of top-k sample decoding. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to `0.7`.
- `top_k`: Number of top most likely candidates from a vocab distribution in each decoding step. Defaults to `40`.

For *top-p decoding*:
- `top_p`: Selects tokens with a cumulative probability of at most `top_p` as candidates for sampling. Do not specify it if you want to use top-k decoding. (A rough sketch of both filtering strategies is shown below.)

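For intuition, the following is a rough sketch of how top-k and top-p filtering restrict the next-token distribution. The `filter_logits` helper is illustrative only; the example itself relies on Texar-PyTorch's `TopKSampleEmbeddingHelper` and `TopPSampleEmbeddingHelper` rather than this code.

```python
import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0,
                  temperature: float = 1.0) -> torch.Tensor:
    """Mask logits so that sampling only draws from top-k / top-p candidates."""
    logits = logits / temperature
    if top_k > 0:
        # Keep only the k most likely tokens.
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p > 0.0:
        # Keep the smallest set of top tokens whose cumulative probability
        # reaches top_p; the single most likely token is always kept.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop_sorted = cum_probs > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False
        drop = torch.zeros_like(drop_sorted).scatter_(-1, sorted_idx, drop_sorted)
        logits = logits.masked_fill(drop, float("-inf"))
    return logits

# One decoding step: sample the next token id from the filtered distribution.
logits = torch.randn(1, 32000)                      # dummy vocabulary logits
probs = torch.softmax(filter_logits(logits, top_k=40, temperature=0.7), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```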

**Example input:**
```
Model input >>> Micheal Jordan is the greatest player in history !
```
**Example output:**
```
======================================== SAMPLE 1 ========================================
He was born George Jordan March 22, 1928, in Tobago, Trinidad and Tobago. Jordan walked super fast
and moved fast. He was also a tremendous downhill skier. He will go down in history with basketball as
an ancient foe.
Teleprint: This publication provides print service through the help of acertified Inter Print Printer.
Teleprint is intended for users who are not physical print service providers ("HSPs") or printers
who are not dealers of or in possession of services offered by a specific HP. Note allowed: Users
who are in possession of services offered by a specific HP are authorized to use high-speed inter print
services.
================================================================================
```

#### Non-interactive mode (to generate samples from scratch)

This mode generates a batch of samples from scratch.

```bash
python xlnet_generation_main.py \
--nsamples=1 \
--batch_size=1 \
--max_decoding_length=100 \
--temperature=0.7 \
--top_k=40
```

Here:

- `nsamples`: Total number of samples to generate; must be divisible by `batch_size`.
- `batch_size`: Each iteration generates `batch_size` samples.

**Example output:**

```
"A new government and a healthy economy have a chance to take this up."
After he said the election's outcome in the House was important and had helped to build
confidence in the House, former Ukip leader Nigel Farage spoke about working to boost
the economy, saying the vote for the "lefties" and others "were bad optics for Labour
in this way".
```

#### IPython mode (to play with different decoding strategies)

The IPython mode allows you to play with different decoding strategies (top-k, top-p, greedy, etc.) and other hyperparameters.

To run the text generation, run the following command:
Install IPython, and run the following command to enter an interactive console.
```bash
python xlnet_generation_ipython.py
```
It is recommended to install IPython before running the command. If IPython is installed, you will enter an interactive
console in which you can perform sampling with different options. Here we show an example output:
Here we show an example output:
```
Generate text by calling: sample("<your prompt text>", ...).
For options, refer to `decode` method of `XLNetDecoder`.
@@ -201,16 +289,16 @@ unicate better, including giving them a "little bit of a leg up" on the English
many "English-speaking" ("in many respects") settlers became.
```

This text generation example is largely inspired by the works of: https://github.com/rusiaaman/XLNet-gen. Especially, we
*This text generation example is largely inspired by https://github.com/rusiaaman/XLNet-gen. In particular, we
borrowed the trick of [adding random text for padding](https://github.com/rusiaaman/XLNet-gen#methodology), so
shorter prompts will not suffer from lack of attentional context.
shorter prompts will not suffer from lack of attentional context.*

## Extend to custom tasks

The interfaces of Texar XLNet are designed to be extensible. You can use your own dataset, or use XLNet as a standalone
module in other tasks.

### Use your own dataset by writing a custom data processor
#### Use your own dataset by writing a custom data processor

It is easy to adapt the code to fine-tune XLNet on your custom dataset. To do this, you will need to write a custom
data processor inheriting `xlnet.data.DataProcessor`. For concrete examples, please refer to the built-in processors
@@ -233,7 +321,7 @@ decorator `@DataProcessor.register("task_name")`.

Now, simply import your processor into `run.py`, and run the training command with the `--task` flag set to your task name. A rough, hypothetical skeleton of such a processor is sketched below.

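The skeleton below only illustrates the overall shape of a custom processor. The class name, label list, and file-reading logic are made up for illustration, and the exact attributes and methods to override should be taken from `xlnet.data.DataProcessor` and the built-in processors.

```python
from xlnet.data import DataProcessor  # provided by this example

@DataProcessor.register("my_task")     # makes the processor available as --task=my_task
class MyTaskProcessor(DataProcessor):
    # Hypothetical sketch -- mirror the built-in GLUE/IMDB processors for the
    # actual attributes and method signatures expected by the training code.
    labels = ["negative", "positive"]

    def get_train_examples(self, data_dir):
        return self._read_examples(f"{data_dir}/train.tsv")

    def get_dev_examples(self, data_dir):
        return self._read_examples(f"{data_dir}/dev.tsv")

    def _read_examples(self, path):
        examples = []
        with open(path, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                text, label = line.rstrip("\n").split("\t")
                examples.append({"guid": idx, "text_a": text, "label": label})
        return examples
```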
### Use XLNet as a standalone module
#### Use XLNet as a standalone module

`xlnet.model.XLNet` can be used as a standalone module in a similar way to a Texar encoder. For convenience, we also
provide `XLNetClassifier` and `XLNetRegressor` for classification and regression tasks. Please refer to module
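
As a rough illustration of the standalone usage, the sketch below treats `XLNetClassifier` like any other PyTorch module. The constructor arguments and forward signature shown here are assumptions rather than the documented API; consult the module docstrings and `xlnet_classification_main.py` for the real interface.

```python
import torch

from xlnet.model import XLNetClassifier  # classifier wrapper provided by this example

# Hypothetical construction -- the real hparams / checkpoint-loading arguments may differ.
model = XLNetClassifier(num_classes=2)
model.eval()

token_ids = torch.randint(0, 32000, (4, 128))  # dummy (batch_size, seq_len) token ids
with torch.no_grad():
    logits = model(token_ids)                  # assumed forward signature
print(logits.shape)                            # expected: (4, num_classes)
```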
8 changes: 2 additions & 6 deletions examples/xlnet/xlnet/model/decoder.py
@@ -171,7 +171,7 @@ def forward(self, start_tokens: torch.LongTensor,
self._state_recompute_memory = recompute_memory
self._state_cache_len = cache_len
self._state_previous_inputs = list(
self.word_embed(start_tokens).unbind(dim=0))
self.word_embed(start_tokens).unbind(dim=0))[:-1]

if helper_type is None:
helper_type = tx.modules.SampleEmbeddingHelper
@@ -180,7 +180,7 @@ def forward(self, start_tokens: torch.LongTensor,
_, memory = self._forward(
memory=memory, cache_len=cache_len,
**self.create_input(
self._state_previous_inputs[:-1], initial=True))
self._state_previous_inputs, initial=True))
start_tokens = start_tokens[-1]

helper_kwargs.update(
@@ -200,8 +200,4 @@ def forward(self, start_tokens: torch.LongTensor,
if print_steps:
print("\033[2K\r", end='')

# Remove the first character since we treat it as warm-up.
output = XLNetDecoderOutput(
output.logits[:, 1:], output.sample_id[:, 1:])

return output, new_memory
24 changes: 12 additions & 12 deletions examples/xlnet/xlnet_generation_ipython.py
@@ -23,17 +23,17 @@ def main():
# prompts. Refer to https://github.com/rusiaaman/XLNet-gen for the rationale
# behind this.
pad_txt = """
In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
narrates the remainder of the story. 1883 Western Siberia, a young
Grigori Rasputin is asked by his father and a group of men to perform
magic. Rasputin has a vision and denounces one of the men as a horse
thief. Although his father initially slaps him for making such an
accusation, Rasputin watches as the man is chased outside and beaten.
Twenty years later, Rasputin sees a vision of the Virgin Mary,
prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. """
Texar-PyTorch is an open-source toolkit based on PyTorch, aiming to
support a broad set of machine learning, especially text generation
tasks, such as machine translation, dialog, summarization, content
manipulation, language modeling, and so on. Texar is designed for both
researchers and practitioners for fast prototyping and
experimentation.
With the design goals of modularity, versatility, and extensibility in
mind, Texar extracts the common patterns underlying the diverse tasks
and methodologies, creates a library of highly reusable modules and
functionalities, and facilitates arbitrary model architectures and
algorithmic paradigms. """
pad_ids = tokenize_fn(pad_txt)
pad_ids.append(xlnet.data.utils.EOD_ID)

@@ -48,7 +48,7 @@ def split_by(xs, y):
yield xs[p:]

@torch.no_grad()
def sample(text: str, length: int = 200, n_samples=3, **kwargs):
def sample(text: str, length: int = 100, n_samples=3, **kwargs):
print("=== Prompt ===")
print(text)
model.eval()
