Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Update GLUE results
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Aug 3, 2019
1 parent c9d41c4 commit 7dc07ad
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 8 deletions.
22 changes: 21 additions & 1 deletion README.md
Expand Up @@ -13,7 +13,7 @@

\[[中文](https://github.com/CyberZHG/keras-xlnet/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-xlnet/blob/master/README.md)\]

Unofficial implementation of [XLNet](https://arxiv.org/pdf/1906.08237).
Unofficial implementation of [XLNet](https://arxiv.org/pdf/1906.08237). [Embedding extraction](demo/extract/token_embeddings.py) and [embedding extract with memory](demo/extract/token_embeddings_with_memory.py) show how to get the results of the last transformer layer using pre-trained checkpoints.

## Install

Expand All @@ -23,6 +23,24 @@ pip install keras-xlnet

## Usage

### Fine-tuning on GLUE

Click the task name to see the demos:

|Task Name |Metrics |Approximate Results on Dev Set|
|:-------------------------------|:----------------------------:|----:|
|[CoLA](demo/GLUE/CoLA/cola.py) |Matthew Corr. |52 |
|[SST-2](demo/GLUE/SST-2/sst2.py)|Accuracy |93 |
|[MRPC](demo/GLUE/MRPC/mrpc.py) |Accuracy/F1 |86/89|
|[STS-B](demo/GLUE/STS-B/stsb.py)|Pearson Corr. / Spearman Corr.|86/87|
|[QQP](demo/GLUE/QQP/qqp.py) |Accuracy/F1 |90/86|
|[MNLI](demo/GLUE/MNLI/mnli.py) |Accuracy |84/84|
|[QNLI](demo/GLUE/QNLI/qnli.py) |Accuracy |86 |
|[RTE](demo/GLUE/RTE/rte.py) |Accuracy |64 |
|[WNLI](demo/GLUE/WNLI/wnli.py) |Accuracy |56 |

(Only 0s are predicted in WNLI dataset)

### Load Pretrained Checkpoints

```python
Expand All @@ -47,6 +65,8 @@ Arguments `batch_size`, `memory_len` and `target_len` are maximum sizes used for

### About I/O

**Note that** `shuffle` should be `False` in either `fit` or `fit_generator` if memories are used.

#### `in_train_phase` is `False`

3 inputs:
Expand Down
22 changes: 21 additions & 1 deletion README.zh-CN.md
Expand Up @@ -13,7 +13,7 @@

\[[中文](https://github.com/CyberZHG/keras-xlnet/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-xlnet/blob/master/README.md)|[通用问题](https://github.com/CyberZHG/summary/blob/master/QA.md)\]

[XLNet](https://arxiv.org/pdf/1906.08237)的非官方实现。
[XLNet](https://arxiv.org/pdf/1906.08237)的非官方实现。[嵌入提取](demo/extract/token_embeddings.py)[有记忆的嵌入提取](demo/extract/token_embeddings_with_memory.py)展示了如何加载预训练检查点并得到transformer的输出特征。

## 安装

Expand All @@ -23,6 +23,24 @@ pip install keras-xlnet

## 使用

### GLUE微调

点击任务名可以查看训练样例:

|任务名 |指标 |验证集上大致结果|
|:-------------------------------|:----------------------------:|----:|
|[CoLA](demo/GLUE/CoLA/cola.py) |Matthew Corr. |52 |
|[SST-2](demo/GLUE/SST-2/sst2.py)|Accuracy |93 |
|[MRPC](demo/GLUE/MRPC/mrpc.py) |Accuracy/F1 |86/89|
|[STS-B](demo/GLUE/STS-B/stsb.py)|Pearson Corr. / Spearman Corr.|86/87|
|[QQP](demo/GLUE/QQP/qqp.py) |Accuracy/F1 |90/86|
|[MNLI](demo/GLUE/MNLI/mnli.py) |Accuracy |84/84|
|[QNLI](demo/GLUE/QNLI/qnli.py) |Accuracy |86 |
|[RTE](demo/GLUE/RTE/rte.py) |Accuracy |64 |
|[WNLI](demo/GLUE/WNLI/wnli.py) |Accuracy |56 |

(注意:WNLI数据集上只输出了0,不是一个正常结果)

### 加载预训练检查点

```python
Expand All @@ -47,6 +65,8 @@ model.summary()

### 关于输入输出

**注意**:依赖记忆时输入有序,一定不能打乱输入顺序,`fit``fit_generator``shuffle`应该为`False`

#### `in_train_phase``False`

3个输入:
Expand Down
11 changes: 6 additions & 5 deletions demo/GLUE/QNLI/qnli.py
Expand Up @@ -42,20 +42,21 @@ def __getitem__(self, index):


def generate_sequence(path):
tokens, classes = [], []
tokens, segments, classes = [], [], []
df = pd.read_csv(path, sep='\t', error_bad_lines=False)
for _, row in df.iterrows():
text_a, text_b, cls = row['question'], row['sentence'], row['label']
if not isinstance(text_a, str) or not isinstance(text_b, str) or cls not in CLASSES:
continue
encoded_a, encoded_b = tokenizer.encode(text_a)[:20], tokenizer.encode(text_b)[:77]
encoded = encoded_a + [tokenizer.SYM_SEP] + encoded_b + [tokenizer.SYM_SEP]
segment = [0] * (len(encoded_a) + 1) + [1] * (len(encoded_b) + 1) + [2]
encoded = [tokenizer.SYM_PAD] * (SEQ_LEN - 1 - len(encoded)) + encoded + [tokenizer.SYM_CLS]
segment = [-1] * (SEQ_LEN - len(segment)) + segment
tokens.append(encoded)
segments.append(segment)
classes.append(CLASSES[cls])
tokens, classes = np.array(tokens), np.array(classes)
segments = np.zeros_like(tokens)
segments[:, -1] = 1
tokens, segments, classes = np.array(tokens), np.array(segments), np.array(classes)
lengths = np.zeros_like(tokens[:, :1])
return DataSequence([tokens, segments, lengths], classes)

Expand Down Expand Up @@ -100,7 +101,7 @@ def generate_sequence(path):
generator=train_seq,
validation_data=dev_seq,
epochs=EPOCH,
callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)],
callbacks=[keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5)],
)

model.save_weights(MODEL_NAME)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -11,7 +11,7 @@

setup(
name='keras-xlnet',
version='0.13.0',
version='0.14.0',
packages=find_packages(),
url='https://github.com/CyberZHG/keras-xlnet',
license='MIT',
Expand Down

0 comments on commit 7dc07ad

Please sign in to comment.