Add sentence_classifier example (#174)
gpengzhi authored and huzecong committed Aug 27, 2019
1 parent 625eea7 commit dfe0332
Showing 5 changed files with 340 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/README.md
@@ -25,6 +25,7 @@ More examples are continuously added...
### Classifier / Sequence Prediction ###

* [bert](./bert): Pre-trained BERT model for text representation
* [sentence_classifier](./sentence_classifier): Basic CNN-based sentence classifier
* [xlnet](./xlnet): Pre-trained XLNet model for text classification/regression

---
@@ -45,5 +46,6 @@ More examples are continuously added...
### Classification ###

* [bert](./bert): Pre-trained BERT model for text representation
* [sentence_classifier](./sentence_classifier): Basic CNN-based sentence classifier
* [xlnet](./xlnet): Pre-trained XLNet model for text classification/regression

37 changes: 37 additions & 0 deletions examples/sentence_classifier/README.md
@@ -0,0 +1,37 @@
# Sentence Sentiment Classifier #

This example builds a convolutional sentence classifier and trains it on the [SST data](https://nlp.stanford.edu/sentiment/index.html). The example configuration [config_kim.py](./config_kim.py) corresponds to the paper
[(Kim) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf).

The example shows:

* Construction of a simple model, composing the `WordEmbedder` and `Conv1DClassifier` modules (sketched below).
* Use of Texar `MultiAlignedData` to read parallel text and label data.
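
As a minimal sketch, the model composition in [clas_main.py](./clas_main.py) boils down to the following (hyperparameter dicts abbreviated; see [config_kim.py](./config_kim.py) for the actual values):

```
import torch.nn as nn
import texar.torch as tx

class SentenceClassifier(nn.Module):
    def __init__(self, vocab_size, max_seq_length, emb_dim, hparams):
        super().__init__()
        self.embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=hparams['embedder'])
        self.classifier = tx.modules.Conv1DClassifier(
            in_channels=max_seq_length, in_features=emb_dim,
            hparams=hparams['classifier'])
```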

## Usage ##

Use the following command to download and prepare the SST binary data:

```
python sst_data_preprocessor.py [--data-path ./data]
```

Here:

* `--data-path` specifies the directory to store the SST data. If the data files do not exist, the program will automatically download, extract, and pre-process the data.
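
After the script finishes, the data directory should contain files along these lines (names as produced by [sst_data_preprocessor.py](./sst_data_preprocessor.py), alongside the downloaded raw `stsa.binary.*` files):

```
data/
├── sst2.train.sentences.txt   # one cleaned, tokenized sentence per line
├── sst2.train.labels.txt      # one 0/1 label per line, aligned with the sentences
├── sst2.dev.sentences.txt
├── sst2.dev.labels.txt
├── sst2.test.sentences.txt
├── sst2.test.labels.txt
└── sst2.vocab                 # vocabulary built from the training sentences
```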

The following command trains the model with Kim's configuration:

```
python clas_main.py --config config_kim
```

Here:

* `--config` specifies the config file to use, e.g., the above command uses the configuration defined in [config_kim.py](./config_kim.py).

The model trains on the training set and evaluates on the validation data after every epoch; whenever a new best validation accuracy is reached, it also evaluates on the test data.

## Results ##

The model achieves around `83%` test set accuracy.
128 changes: 128 additions & 0 deletions examples/sentence_classifier/clas_main.py
@@ -0,0 +1,128 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for building a sentence convolutional classifier.
Use `./sst_data_preprocessor.py` to download and clean the SST binary data.
To run:
$ python clas_main.py --config=config_kim
"""

from typing import Any, Dict, Tuple

import argparse
import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F

import texar.torch as tx

parser = argparse.ArgumentParser()
parser.add_argument(
    '--config', type=str, default='config_kim',
    help='The config to use.')
args = parser.parse_args()

config = importlib.import_module(args.config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SentenceClassifier(nn.Module):

    def __init__(self, vocab_size: int, max_seq_length: int,
                 emb_dim: int, hparams: Dict[str, Any]):
        super().__init__()

        # Word embedding layer over the vocabulary.
        self.embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=hparams['embedder'])
        # CNN classifier over the embedded sequence; the fixed sequence
        # length serves as the number of input channels.
        self.classifier = tx.modules.Conv1DClassifier(
            in_channels=max_seq_length,
            in_features=emb_dim, hparams=hparams['classifier'])

    def forward(self, batch: tx.data.Batch) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        logits, pred = self.classifier(
            self.embedder(batch['sentence_text_ids']))
        loss = F.cross_entropy(logits, batch['label'].long())
        return pred, loss


def main():
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data, device=device)
    val_data = tx.data.MultiAlignedData(config.val_data, device=device)
    test_data = tx.data.MultiAlignedData(config.test_data, device=device)
    data_iterator = tx.data.TrainTestDataIterator(
        train=train_data, val=val_data, test=test_data)

    hparams = {
        'embedder': config.emb,
        'classifier': config.clas
    }
    model = SentenceClassifier(vocab_size=train_data.vocab('sentence').size,
                               max_seq_length=config.max_seq_length,
                               emb_dim=config.emb_dim,
                               hparams=hparams)
    model.to(device)
    train_op = tx.core.get_train_op(params=model.parameters(),
                                    hparams=config.opt)
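    # Note: per `tx.core.get_train_op`, calling `train_op()` applies the
    # optimizer step and then zeroes the gradients, so the loop below needs
    # no explicit `zero_grad()` call.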

    def _run_epoch(mode, epoch):

        step = 0
        avg_rec = tx.utils.AverageRecorder()
        for batch in data_iterator:
            pred, loss = model(batch)
            if mode == "train":
                loss.backward()
                train_op()
            accu = tx.evals.accuracy(batch['label'], pred)
            step += 1
            if step == 1 or step % 100 == 0:
                print(f"epoch: {epoch:2} step: {step:4} accu: {accu:.4f}")

            batch_size = batch['label'].size(0)
            avg_rec.add([accu], batch_size)

        return avg_rec.avg(0)

    best_val_accu = -1
    for epoch in range(config.num_epochs):
        # Train
        data_iterator.switch_to_train_data()
        model.train()
        train_accu = _run_epoch("train", epoch)

        # Val
        data_iterator.switch_to_val_data()
        model.eval()
        val_accu = _run_epoch("val", epoch)
        print(f'epoch: {epoch:2} train accu: {train_accu:.4f} '
              f'val accu: {val_accu:.4f}')

        # Test
        if val_accu > best_val_accu:
            best_val_accu = val_accu
            data_iterator.switch_to_test_data()
            model.eval()
            test_accu = _run_epoch("test", epoch)
            print(f'test accu: {test_accu:.4f}')


if __name__ == '__main__':
    main()
77 changes: 77 additions & 0 deletions examples/sentence_classifier/config_kim.py
@@ -0,0 +1,77 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sentence convolutional classifier config.
This is (approximately) the config of the paper:
(Kim) Convolutional Neural Networks for Sentence Classification
https://arxiv.org/pdf/1408.5882.pdf
"""

import copy

num_epochs = 15
max_seq_length = 56
emb_dim = 300

train_data = {
    "batch_size": 50,
    "datasets": [
        {
            "files": "./data/sst2.train.sentences.txt",
            "vocab_file": "./data/sst2.vocab",
            # Discards samples with length > 56
            "max_seq_length": max_seq_length,
            "length_filter_mode": "discard",
            # Do not append BOS/EOS tokens to the sentences
            "bos_token": "",
            "eos_token": "",
            "data_name": "sentence"
        },
        {
            "files": "./data/sst2.train.labels.txt",
            "data_type": "int",
            "data_name": "label"
        }
    ]
}
# The val and test data have the same config as the train data, except
# for the file names
val_data = copy.deepcopy(train_data)
val_data["datasets"][0]["files"] = "./data/sst2.dev.sentences.txt"
val_data["datasets"][1]["files"] = "./data/sst2.dev.labels.txt"
test_data = copy.deepcopy(train_data)
test_data["datasets"][0]["files"] = "./data/sst2.test.sentences.txt"
test_data["datasets"][1]["files"] = "./data/sst2.test.labels.txt"

# Word embedding
emb = {
    "dim": emb_dim
}

# Classifier
clas = {
    "num_conv_layers": 1,
    "out_channels": 100,
    "kernel_size": [3, 4, 5],
    "conv_activation": "ReLU",
    "pooling": "MaxPool1d",
    "num_dense_layers": 0,
    "dropout_conv": [1],
    "dropout_rate": 0.5,
    "num_classes": 2
}
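# The settings above follow Kim (2014): three parallel convolution filter
# sizes (3, 4, 5) with 100 feature maps each, max pooling, dropout 0.5, and
# a 2-way output layer for binary sentiment.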

# Optimization
# Just use the default config, i.e., the Adam optimizer
opt = {}
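
# `opt = {}` falls back to Texar's optimization defaults (an Adam optimizer).
# A hedged sketch of an explicit override, assuming the
# {"optimizer": {"type": ..., "kwargs": ...}} layout of
# `tx.core.default_optimization_hparams()`:
#
# opt = {
#     "optimizer": {
#         "type": "Adam",
#         "kwargs": {"lr": 1e-3}
#     }
# }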
96 changes: 96 additions & 0 deletions examples/sentence_classifier/sst_data_preprocessor.py
@@ -0,0 +1,96 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preparing the SST2 dataset.
"""

from typing import Tuple

import argparse
import os
import re

import texar.torch as tx

parser = argparse.ArgumentParser()
parser.add_argument(
    '--data-path', type=str, default='./data',
    help="Directory to store the SST data. If the data files do not exist, "
         "the directory will be created and the raw SST data will be "
         "downloaded, extracted, and pre-processed.")
args = parser.parse_args()


def clean_sst_text(text: str) -> str:
    """Cleans tokens in the SST data, which has already been tokenized.
    """
    text = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip().lower()
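
# For example (an illustrative string, not taken from the data files):
#   clean_sst_text("The Rock is destined to be the 21st Century's ``Conan''.")
#   -> "the rock is destined to be the 21st century's ``conan''"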


def transform_raw_sst(data_path: str, raw_filename: str, new_filename: str) -> \
        Tuple[str, str]:
    """Transforms the raw data format to a new format.
    """
    fout_x_name = os.path.join(data_path, new_filename + '.sentences.txt')
    fout_y_name = os.path.join(data_path, new_filename + '.labels.txt')
    fin_name = os.path.join(data_path, raw_filename)
    # Use context managers so the output files are properly closed.
    with open(fout_x_name, 'w', encoding='utf-8') as fout_x, \
            open(fout_y_name, 'w', encoding='utf-8') as fout_y, \
            open(fin_name, 'r', encoding='utf-8') as fin:
        # Each raw line is "<label> <tokenized sentence>".
        for line in fin:
            parts = line.strip().split()
            label = parts[0]
            sent = ' '.join(parts[1:])
            sent = clean_sst_text(sent)
            fout_x.write(sent + '\n')
            fout_y.write(label + '\n')

    return fout_x_name, fout_y_name
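
# For example, a hypothetical raw line
#   "1 a warm , funny movie ."
# yields "a warm , funny movie" in sst2.*.sentences.txt and "1" in
# sst2.*.labels.txt.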


def prepare_data():
    """Preprocesses SST2 data.
    """
    train_path = os.path.join(args.data_path, "sst2.train.sentences.txt")
    if not os.path.exists(train_path):
        url = ('https://raw.githubusercontent.com/ZhitingHu/'
               'logicnn/master/data/raw/')
        files = ['stsa.binary.phrases.train', 'stsa.binary.dev',
                 'stsa.binary.test']
        for fn in files:
            tx.data.maybe_download(url + fn, args.data_path, extract=True)

    fn_train, _ = transform_raw_sst(
        args.data_path, 'stsa.binary.phrases.train', 'sst2.train')
    transform_raw_sst(args.data_path, 'stsa.binary.dev', 'sst2.dev')
    transform_raw_sst(args.data_path, 'stsa.binary.test', 'sst2.test')

    vocab = tx.data.make_vocab(fn_train)
    fn_vocab = os.path.join(args.data_path, 'sst2.vocab')
    with open(fn_vocab, 'w', encoding='utf-8') as f_vocab:
        for v in vocab:
            f_vocab.write(v + '\n')

    print('Preprocessing done: {}'.format(args.data_path))


def main():
    """Entrypoint.
    """
    prepare_data()


if __name__ == '__main__':
    main()
