
Commit

Update README.md
CyberZHG committed Jun 20, 2019
1 parent db3032e commit 81c6775
Showing 4 changed files with 119 additions and 4 deletions.
44 changes: 44 additions & 0 deletions README.md
@@ -160,6 +160,50 @@ total_steps, warmup_steps = calc_train_steps(
optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)
```

### Extract Features

You can use the helper function `extract_embeddings` if all you need are the features of tokens or sentences (without further fine-tuning). To extract the features of all tokens:

```python
from keras_bert import extract_embeddings

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'
texts = ['all work and no play', 'makes jack a dull boy~']

embeddings = extract_embeddings(model_path, texts)
```

The returned result is a list with the same length as `texts`. Each item in the list is a numpy array truncated to the length of its input. The shapes of the outputs in this example are `(8, 768)` and `(9, 768)`.
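
For a quick look at those shapes, a minimal sketch reusing `texts` and `embeddings` from the snippet above:

```python
for text, embedding in zip(texts, embeddings):
    # One row per input token (after tokenization), 768 hidden units per row.
    print(repr(text), embedding.shape)
```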

When the inputs are paired sentences and you need the `NSP` output together with max-pooling over the last 4 layers:

```python
from keras_bert import extract_embeddings, POOL_NSP, POOL_MAX

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'
texts = [
('all work and no play', 'makes jack a dull boy'),
('makes jack a dull boy', 'all work and no play'),
]

embeddings = extract_embeddings(model_path, texts, output_layer_num=4, poolings=[POOL_NSP, POOL_MAX])
```

There are no token features in the results; the `NSP` and max-pooling outputs are concatenated, so each array has the final shape `(768 x 4 x 2,)`.
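
A quick sanity check of that shape, reusing `embeddings` from the snippet above:

```python
# 768 hidden units x 4 layers x 2 poolings (NSP and max) = 6144 values per pair.
for embedding in embeddings:
    assert embedding.shape == (768 * 4 * 2,)
```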

The second argument of the helper function also accepts a generator. To extract features from a file:

```python
import codecs
from keras_bert import extract_embeddings

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'

with codecs.open('xxx.txt', 'r', 'utf8') as reader:
texts = map(lambda x: x.strip(), reader)
embeddings = extract_embeddings(model_path, texts)
```
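
If the corpus is too large to keep every feature in memory, the underlying generator in `keras_bert/util.py` (see the diff below) can be used instead. A minimal sketch, assuming `extract_embeddings_generator` is imported from `keras_bert.util` and consumed while the file is still open:

```python
import codecs
from keras_bert.util import extract_embeddings_generator

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'

with codecs.open('xxx.txt', 'r', 'utf8') as reader:
    texts = map(lambda x: x.strip(), reader)
    # Yields one numpy array per line instead of materializing the whole list.
    for embedding in extract_embeddings_generator(model_path, texts):
        print(embedding.shape)
```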

### Use `tensorflow.python.keras`

Add `TF_KERAS=1` to the environment variables to use `tensorflow.python.keras`.
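
A minimal sketch, assuming the variable has to be set before `keras_bert` is imported:

```python
import os

os.environ['TF_KERAS'] = '1'  # switch the backend to tensorflow.python.keras

from keras_bert import extract_embeddings  # import only after the variable is set
```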
44 changes: 44 additions & 0 deletions README.zh-CN.md
@@ -173,6 +173,50 @@ optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)

When `training` is `True`, the input contains three items: token indices, segment indices, and the template of masked words. When `training` is `False`, the input contains only the first two. Since the position indices are fixed, they are generated inside the model and do not need to be fed in manually. The template of masked words has the value 1 at positions where the input word is masked, and 0 otherwise.
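
A purely hypothetical illustration of those three inputs, for a single sequence of length 6 whose third token is masked (the ids below are made up and do not come from a real vocabulary):

```python
import numpy as np

token_ids = np.array([[101, 7592, 103, 2088, 999, 102]])   # token indices (made-up ids)
segment_ids = np.array([[0, 0, 0, 0, 0, 0]])               # segment indices
masked_template = np.array([[0, 0, 1, 0, 0, 0]])           # 1 where the input word is masked
```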

### Extract Features

If you do not need fine-tuning and only want to extract token/sentence features, you can use `extract_embeddings` to simplify the process. For example, to extract the features of all tokens in each sentence:

```python
from keras_bert import extract_embeddings

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'
texts = ['all work and no play', 'makes jack a dull boy~']

embeddings = extract_embeddings(model_path, texts)
```

The returned result is a list with the same length as the input texts. Each element is a numpy array, and by default it is truncated to the length of its input, so the output shapes in this example are `(8, 768)` and `(9, 768)`.

If the inputs are paired sentences and you want the features of the last 4 layers, extracting the `NSP` output and the max-pooling result, you can use:

```python
from keras_bert import extract_embeddings, POOL_NSP, POOL_MAX

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'
texts = [
('all work and no play', 'makes jack a dull boy'),
('makes jack a dull boy', 'all work and no play'),
]

embeddings = extract_embeddings(model_path, texts, output_layer_num=4, poolings=[POOL_NSP, POOL_MAX])
```

The outputs no longer contain token features; the `NSP` and max-pooling outputs are concatenated, and each numpy array has the shape `(768 x 4 x 2,)`.

The second argument accepts a generator. To read a file and generate features, you can use the following approach:

```python
import codecs
from keras_bert import extract_embeddings

model_path = 'xxx/yyy/uncased_L-12_H-768_A-12'

with codecs.open('xxx.txt', 'r', 'utf8') as reader:
texts = map(lambda x: x.strip(), reader)
embeddings = extract_embeddings(model_path, texts)
```

### Use `tensorflow.python.keras`

Add `TF_KERAS=1` to the environment variables to enable `tensorflow.python.keras`. Add `TF_EAGER=1` to enable eager execution. Until Keras itself supports it, you must set `TF_KERAS=1` if you want to use TensorFlow 2.0.
25 changes: 22 additions & 3 deletions keras_bert/util.py
@@ -31,7 +31,14 @@ def get_checkpoint_paths(model_path):
return CheckpointPaths(config_path, checkpoint_path, vocab_path)


def extract_embeddings_generator(model, texts, poolings=None, vocabs=None, cased=False, batch_size=4, cut_embed=True):
def extract_embeddings_generator(model,
texts,
poolings=None,
vocabs=None,
cased=False,
batch_size=4,
cut_embed=True,
output_layer_num=1):
"""Extract embeddings from texts.
:param model: Path to the checkpoint or built model without MLM and NSP.
@@ -42,13 +49,16 @@ def extract_embeddings_generator(model, texts, poolings=None, vocabs=None, cased
:param cased: Whether it is cased for tokenizer.
:param batch_size: Batch size.
:param cut_embed: The computed embeddings will be cut based on their input lengths.
:param output_layer_num: The number of layers whose outputs will be concatenated as a single output.
Only available when `model` is a path to checkpoint.
:return: A list of numpy arrays representing the embeddings.
"""
if isinstance(model, (str, type(u''))):
paths = get_checkpoint_paths(model)
model = load_trained_model_from_checkpoint(
config_file=paths.config,
checkpoint_file=paths.checkpoint,
output_layer_num=output_layer_num,
)
vocabs = load_vocabulary(paths.vocab)

@@ -111,7 +121,14 @@ def _pad_inputs():
yield output


def extract_embeddings(model, texts, poolings=None, vocabs=None, cased=False, batch_size=4, cut_embed=True):
def extract_embeddings(model,
texts,
poolings=None,
vocabs=None,
cased=False,
batch_size=4,
cut_embed=True,
output_layer_num=1):
"""Extract embeddings from texts.
:param model: Path to the checkpoint or built model without MLM and NSP.
@@ -122,8 +139,10 @@ def extract_embeddings(model, texts, poolings=None, vocabs=None, cased=False, ba
:param cased: Whether it is cased for tokenizer.
:param batch_size: Batch size.
:param cut_embed: The computed embeddings will be cut based on their input lengths.
:param output_layer_num: The number of layers whose outputs will be concatenated as a single output.
Only available when `model` is a path to checkpoint.
:return: A list of numpy arrays representing the embeddings.
"""
return [embedding for embedding in extract_embeddings_generator(
model, texts, poolings, vocabs, cased, batch_size, cut_embed
model, texts, poolings, vocabs, cased, batch_size, cut_embed, output_layer_num
)]
10 changes: 9 additions & 1 deletion tests/test_util.py
@@ -2,6 +2,7 @@
from __future__ import unicode_literals
import unittest
import os
import codecs
from keras_bert.backend import keras
from keras_bert import get_model, POOL_NSP, POOL_MAX, POOL_AVE, extract_embeddings

@@ -59,9 +60,10 @@ def test_extract_embeddings_multi_pooling(self):
('makes jack a dull boy', 'all work and no play'),
],
poolings=[POOL_NSP, POOL_MAX, POOL_AVE],
output_layer_num=2,
)
self.assertEqual(2, len(embeddings))
self.assertEqual((12,), embeddings[0].shape)
self.assertEqual((24,), embeddings[0].shape)

def test_extract_embeddings_invalid_pooling(self):
with self.assertRaises(ValueError):
@@ -104,3 +106,9 @@ def test_extract_embeddings_variable_lengths(self):
self.assertEqual(2, len(embeddings))
self.assertEqual((10, 13), embeddings[0].shape)
self.assertEqual((14, 13), embeddings[1].shape)

def test_extract_embeddings_from_file(self):
with codecs.open(os.path.join(self.model_path, 'vocab.txt'), 'r', 'utf8') as reader:
texts = map(lambda x: x.strip(), reader)
embeddings = extract_embeddings(self.model_path, texts)
self.assertEqual(15, len(embeddings))
