Add sentence_classifier example (#174)
Showing 5 changed files with 340 additions and 0 deletions.
README.md
@@ -0,0 +1,37 @@
# Sentence Sentiment Classifier #

This example builds a sentence convolutional classifier and trains it on the [SST data](https://nlp.stanford.edu/sentiment/index.html). The example configuration [config_kim.py](./config_kim.py) corresponds to the paper
[(Kim) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf).

The example shows:

* Construction of a simple model involving the `Embedder` and `Conv1DClassifier`, as shown in the sketch after this list.
* Use of Texar `MultiAlignedData` to read parallel text and label data.
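For orientation, here is a minimal sketch of how the two modules compose, mirroring the `SentenceClassifier` class in `clas_main.py`. The sizes and the random batch are illustrative only, and the hyperparameter dicts are abbreviated stand-ins for those in `config_kim.py`:

```
import torch
import torch.nn.functional as F
import texar.torch as tx

max_seq_length, emb_dim, vocab_size = 56, 300, 10000   # illustrative sizes

# Token ids -> word embeddings -> 1D convolutional classifier.
embedder = tx.modules.WordEmbedder(
    vocab_size=vocab_size, hparams={"dim": emb_dim})
classifier = tx.modules.Conv1DClassifier(
    in_channels=max_seq_length,   # sentences padded/truncated to this length
    in_features=emb_dim, hparams={"num_classes": 2})

token_ids = torch.randint(vocab_size, (50, max_seq_length))  # a dummy batch
labels = torch.randint(2, (50,))
logits, pred = classifier(embedder(token_ids))  # logits: (50, 2); pred: (50,)
loss = F.cross_entropy(logits, labels.long())
```
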
## Usage ##

Use the following command to download and prepare the SST binary data:

```
python sst_data_preprocessor.py [--data-path ./data]
```

Here:

* `--data-path` specifies the directory in which to store the SST data. If the data files do not exist, the program automatically downloads, extracts, and pre-processes the data.
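After preprocessing, the data directory should contain (alongside the downloaded raw files) the files that `config_kim.py` and `clas_main.py` below expect:

```
data/
├── sst2.train.sentences.txt
├── sst2.train.labels.txt
├── sst2.dev.sentences.txt
├── sst2.dev.labels.txt
├── sst2.test.sentences.txt
├── sst2.test.labels.txt
└── sst2.vocab
```
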
The following command trains the model with Kim's configuration:

```
python clas_main.py --config config_kim
```

Here:

* `--config` specifies the config file to use. E.g., the above command uses the configuration defined in [config_kim.py](./config_kim.py).

The model trains and evaluates on the validation data after each epoch, and evaluates on the test data whenever a new best validation accuracy is reached, as sketched below.
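In outline, the epoch loop in `clas_main.py` implements this schedule as follows (`run_epoch` stands in for the `_run_epoch` helper defined in that file):

```
best_val_accu = -1
for epoch in range(config.num_epochs):
    train_accu = run_epoch("train", epoch)  # one optimization pass
    val_accu = run_epoch("val", epoch)      # evaluate on the dev set
    if val_accu > best_val_accu:            # test only on a new best val accuracy
        best_val_accu = val_accu
        test_accu = run_epoch("test", epoch)
```
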
## Results ##

The model achieves around `83%` test set accuracy.
clas_main.py
@@ -0,0 +1,128 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for building a sentence convolutional classifier.

Use `./sst_data_preprocessor.py` to download and clean the SST binary data.

To run:
    $ python clas_main.py --config=config_kim
"""

from typing import Any, Dict, Tuple

import argparse
import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F

import texar.torch as tx

parser = argparse.ArgumentParser()
parser.add_argument(
    '--config', type=str, default='config_kim',
    help='The config to use.')
args = parser.parse_args()

config = importlib.import_module(args.config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SentenceClassifier(nn.Module):

    def __init__(self, vocab_size: int, max_seq_length: int,
                 emb_dim: int, hparams: Dict[str, Any]):
        super().__init__()

        self.embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=hparams['embedder'])
        self.classifier = tx.modules.Conv1DClassifier(
            in_channels=max_seq_length,
            in_features=emb_dim, hparams=hparams['classifier'])

    def forward(self, batch: tx.data.Batch) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        logits, pred = self.classifier(
            self.embedder(batch['sentence_text_ids']))
        loss = F.cross_entropy(logits, batch['label'].long())
        return pred, loss


def main():
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data, device=device)
    val_data = tx.data.MultiAlignedData(config.val_data, device=device)
    test_data = tx.data.MultiAlignedData(config.test_data, device=device)
    data_iterator = tx.data.TrainTestDataIterator(
        train=train_data, val=val_data, test=test_data)

    hparams = {
        'embedder': config.emb,
        'classifier': config.clas
    }
    model = SentenceClassifier(vocab_size=train_data.vocab('sentence').size,
                               max_seq_length=config.max_seq_length,
                               emb_dim=config.emb_dim,
                               hparams=hparams)
    model.to(device)
    # `train_op` applies the optimizer update and clears the gradients.
    train_op = tx.core.get_train_op(params=model.parameters(),
                                    hparams=config.opt)

    def _run_epoch(mode, epoch):

        step = 0
        avg_rec = tx.utils.AverageRecorder()
        for batch in data_iterator:
            pred, loss = model(batch)
            if mode == "train":
                loss.backward()
                train_op()
            accu = tx.evals.accuracy(batch['label'], pred)
            step += 1
            if step == 1 or step % 100 == 0:
                print(f"epoch: {epoch:2} step: {step:4} accu: {accu:.4f}")

            batch_size = batch['label'].size(0)
            avg_rec.add([accu], batch_size)

        return avg_rec.avg(0)

    best_val_accu = -1
    for epoch in range(config.num_epochs):
        # Train
        data_iterator.switch_to_train_data()
        model.train()
        train_accu = _run_epoch("train", epoch)

        # Val
        data_iterator.switch_to_val_data()
        model.eval()
        val_accu = _run_epoch("val", epoch)
        print(f'epoch: {epoch:2} train accu: {train_accu:.4f} '
              f'val accu: {val_accu:.4f}')

        # Test
        if val_accu > best_val_accu:
            best_val_accu = val_accu
            data_iterator.switch_to_test_data()
            model.eval()
            test_accu = _run_epoch("test", epoch)
            print(f'test accu: {test_accu:.4f}')


if __name__ == '__main__':
    main()
config_kim.py
@@ -0,0 +1,77 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sentence convolutional classifier config.

This is (approximately) the config of the paper:
    (Kim) Convolutional Neural Networks for Sentence Classification
    https://arxiv.org/pdf/1408.5882.pdf
"""

import copy

num_epochs = 15
max_seq_length = 56
emb_dim = 300

train_data = {
    "batch_size": 50,
    "datasets": [
        {
            "files": "./data/sst2.train.sentences.txt",
            "vocab_file": "./data/sst2.vocab",
            # Discards samples with length > 56
            "max_seq_length": max_seq_length,
            "length_filter_mode": "discard",
            # Do not append BOS/EOS tokens to the sentences
            "bos_token": "",
            "eos_token": "",
            "data_name": "sentence"
        },
        {
            "files": "./data/sst2.train.labels.txt",
            "data_type": "int",
            "data_name": "label"
        }
    ]
}
# The val and test data have the same config as the train data, except
# for the file names.
val_data = copy.deepcopy(train_data)
val_data["datasets"][0]["files"] = "./data/sst2.dev.sentences.txt"
val_data["datasets"][1]["files"] = "./data/sst2.dev.labels.txt"
test_data = copy.deepcopy(train_data)
test_data["datasets"][0]["files"] = "./data/sst2.test.sentences.txt"
test_data["datasets"][1]["files"] = "./data/sst2.test.labels.txt"

# Word embedding
emb = {
    "dim": emb_dim
}

# Classifier
clas = {
    "num_conv_layers": 1,
    "out_channels": 100,
    "kernel_size": [3, 4, 5],
    "conv_activation": "ReLU",
    "pooling": "MaxPool1d",
    "num_dense_layers": 0,
    "dropout_conv": [1],
    "dropout_rate": 0.5,
    "num_classes": 2
}

# Optimization
# Just use the default config, e.g., the Adam optimizer
opt = {}
sst_data_preprocessor.py
@@ -0,0 +1,96 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preparing the SST2 dataset.
"""

from typing import Tuple

import argparse
import os
import re

import texar.torch as tx

parser = argparse.ArgumentParser()
parser.add_argument(
    '--data-path', type=str, default='./data',
    help="Directory that stores the SST data, e.g., "
         "./data/sst2.train.sentences.txt. If it does not exist, the "
         "directory will be created and the SST raw data will be downloaded.")
args = parser.parse_args()


def clean_sst_text(text: str) -> str:
    """Cleans tokens in the SST data, which has already been tokenized.
    """
    text = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip().lower()


def transform_raw_sst(data_path: str, raw_filename: str, new_filename: str) -> \
        Tuple[str, str]:
    """Transforms the raw data format to a new format.
    """
    fout_x_name = os.path.join(data_path, new_filename + '.sentences.txt')
    fout_x = open(fout_x_name, 'w', encoding='utf-8')
    fout_y_name = os.path.join(data_path, new_filename + '.labels.txt')
    fout_y = open(fout_y_name, 'w', encoding='utf-8')

    fin_name = os.path.join(data_path, raw_filename)
    with open(fin_name, 'r', encoding='utf-8') as fin:
        for line in fin:
            # Each raw line is "<label> <tokenized sentence>".
            parts = line.strip().split()
            label = parts[0]
            sent = ' '.join(parts[1:])
            sent = clean_sst_text(sent)
            fout_x.write(sent + '\n')
            fout_y.write(label + '\n')
    fout_x.close()
    fout_y.close()

    return fout_x_name, fout_y_name


def prepare_data():
    """Preprocesses SST2 data.
    """
    # Check one of the produced files to decide whether preprocessing
    # is needed.
    train_path = os.path.join(args.data_path, "sst2.train.sentences.txt")
    if not os.path.exists(train_path):
        url = ('https://raw.githubusercontent.com/ZhitingHu/'
               'logicnn/master/data/raw/')
        files = ['stsa.binary.phrases.train', 'stsa.binary.dev',
                 'stsa.binary.test']
        for fn in files:
            tx.data.maybe_download(url + fn, args.data_path, extract=True)

        fn_train, _ = transform_raw_sst(
            args.data_path, 'stsa.binary.phrases.train', 'sst2.train')
        transform_raw_sst(args.data_path, 'stsa.binary.dev', 'sst2.dev')
        transform_raw_sst(args.data_path, 'stsa.binary.test', 'sst2.test')

        vocab = tx.data.make_vocab(fn_train)
        fn_vocab = os.path.join(args.data_path, 'sst2.vocab')
        with open(fn_vocab, 'w', encoding='utf-8') as f_vocab:
            for v in vocab:
                f_vocab.write(v + '\n')

    print('Preprocessing done: {}'.format(args.data_path))


def main():
    """Entrypoint.
    """
    prepare_data()


if __name__ == '__main__':
    main()