This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Update documents
CyberZHG committed Jul 15, 2019
1 parent 80ebe70 commit 027e249
Showing 3 changed files with 96 additions and 2 deletions.
48 changes: 47 additions & 1 deletion README.md
@@ -23,4 +23,50 @@ pip install keras-xlnet

## Usage

TODO
### Load Pretrained Checkpoints

```python
import os
from keras_xlnet import load_trained_model_from_checkpoint

checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,
)
model.summary()
```

The arguments `batch_size`, `memory_len` and `target_len` are the maximum sizes used to initialize the memories. If `in_train_phase` is `True`, a model for training a language model is returned; otherwise, a model for fine-tuning is returned.

### About I/O

#### `in_train_phase` is `False`

3 inputs:

* IDs of tokens, with shape `(batch_size, target_len)`.
* IDs of segments, with shape `(batch_size, target_len)`.
* Length of memories, with shape `(batch_size, 1)`.

1 output:

* The feature for each token, with shape `(batch_size, target_len, units)`.
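
As a rough illustration of the three inputs above, the sketch below pads a batch of token and segment ID sequences to `target_len` and pairs each sample with a memory length. The helper name `build_finetune_inputs`, the right-side padding, the padding segment ID of 0, the default memory length of 0 and the `pad_id` value are all assumptions made for this example and are not part of `keras-xlnet`.

```python
def build_finetune_inputs(token_ids, segment_ids, target_len, pad_id, memory_length=0):
    """Return (tokens, segments, memory_lengths) as nested lists.

    Hypothetical helper: shapes are (batch, target_len), (batch, target_len)
    and (batch, 1), matching the three inputs listed above.
    """
    if len(token_ids) != len(segment_ids):
        raise ValueError('token_ids and segment_ids must have the same batch size')
    tokens, segments, memory_lengths = [], [], []
    for sample_tokens, sample_segments in zip(token_ids, segment_ids):
        if len(sample_tokens) != len(sample_segments):
            raise ValueError('each sample needs one segment ID per token')
        if len(sample_tokens) > target_len:
            raise ValueError('sample is longer than target_len')
        pad = target_len - len(sample_tokens)
        # Assumption: pad on the right, using segment ID 0 for padded positions.
        tokens.append(list(sample_tokens) + [pad_id] * pad)
        segments.append(list(sample_segments) + [0] * pad)
        # One memory length per sample gives the (batch, 1) shape.
        memory_lengths.append([memory_length])
    return tokens, segments, memory_lengths
```

The nested lists would still need converting to arrays (for instance with `numpy.asarray`) before being passed to the model, and the correct `pad_id` depends on the vocabulary shipped with the checkpoint.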

#### `in_train_phase` is `True`

4 inputs:

* IDs of tokens, with shape `(batch_size, target_len)`.
* IDs of segments, with shape `(batch_size, target_len)`.
* Length of memories, with shape `(batch_size, 1)`.
* Masks of tokens, with shape `(batch_size, target_len)`.
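
The fourth input can be pictured as a 0/1 grid with the same shape as the token IDs. The sketch below draws such a grid at random. The helper name `random_token_mask`, the convention that 1 marks a masked position, and the default `mask_prob` of 0.15 are assumptions chosen only for illustration; they are not taken from `keras-xlnet`.

```python
import random


def random_token_mask(batch_size, target_len, mask_prob=0.15, seed=None):
    """Return a (batch_size, target_len) nested list of 0/1 values.

    Hypothetical helper: each position is set to 1 independently with
    probability mask_prob (assumed here to mean "this token is masked").
    """
    rng = random.Random(seed)
    return [
        [1 if rng.random() < mask_prob else 0 for _ in range(target_len)]
        for _ in range(batch_size)
    ]
```

Passing a fixed `seed` makes the mask reproducible, which can help when debugging a training loop.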

1 output:

* The probability of each token in each position, with shape `(batch_size, target_len, num_token)`.
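
One way to read an output of this shape is to pick the most probable token ID at every position. The sketch below does that on plain nested lists; the function name `most_likely_tokens` is hypothetical, and the input is assumed to already be converted from the model's output array (for example with `.tolist()`).

```python
def most_likely_tokens(probabilities):
    """Map (batch, target_len, num_token) probabilities to (batch, target_len) token IDs.

    For each position, return the index of the largest probability
    (the first such index when there is a tie).
    """
    return [
        [max(range(len(dist)), key=dist.__getitem__) for dist in sequence]
        for sequence in probabilities
    ]
```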
48 changes: 48 additions & 0 deletions README.zh-CN.md
@@ -22,3 +22,51 @@ pip install keras-xlnet
```

## Usage

### Load Pretrained Checkpoints

```python
import os
from keras_xlnet import load_trained_model_from_checkpoint

checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,
)
model.summary()
```

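The arguments `batch_size`, `memory_len` and `target_len` are used to initialize the memories and represent maximum sizes; actual inputs may be smaller than these values. If `in_train_phase` is `True`, a model for training a language model is returned; otherwise, a model for fine-tuning is returned.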

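### About I/O

#### `in_train_phase` is `False`

3 inputs:

* IDs of tokens, with shape `(batch_size, target_len)`.
* IDs of segments, with shape `(batch_size, target_len)`.
* Length of memories, with shape `(batch_size, 1)`.

1 output:

* The feature for each token, with shape `(batch_size, target_len, units)`.

#### `in_train_phase` is `True`

4 inputs, the first three being the same as when `in_train_phase` is `False`:

* IDs of tokens, with shape `(batch_size, target_len)`.
* IDs of segments, with shape `(batch_size, target_len)`.
* Length of memories, with shape `(batch_size, 1)`.
* Mask of the masked tokens, with shape `(batch_size, target_len)`.

1 output:

* The probability of each token in each position, with shape `(batch_size, target_len, num_token)`.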
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@

setup(
    name='keras-xlnet',
-   version='0.0.0',
+   version='0.0.1',
    packages=find_packages(),
    url='https://github.com/CyberZHG/keras-xlnet',
    license='MIT',