Commit: Update trainer

Hironsan committed Nov 18, 2017
1 parent 62392b2 commit c2e63ae

Showing 9 changed files with 75 additions and 119 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -63,6 +63,7 @@ import anago
from anago.data.reader import load_data_and_labels, load_word_embeddings
from anago.data.preprocess import prepare_preprocessor
from anago.config import ModelConfig, TrainingConfig
from anago.models import SeqLabeling
```
These imports provide the data-loading utilities, the preprocessor, the configuration classes, and the sequence-labeling model.
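For reference, here is a minimal sketch of how the imported config classes might be instantiated; the fields named in the comments are taken from how the configs are used elsewhere in this commit, not a complete listing:

```python
model_config = ModelConfig()        # model hyperparameters, e.g. word_embedding_size, char_vocab_size
training_config = TrainingConfig()  # training hyperparameters, e.g. batch_size, learning_rate, max_epoch
```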

@@ -89,7 +90,7 @@ x_valid, y_valid = load_data_and_labels(valid_path)
x_test, y_test = load_data_and_labels(test_path)
```

After reading the data, prepare preprocessor and pre-trained word embeddings:
After reading the data, build the preprocessor and load the pre-trained word embeddings:
```python
p = prepare_preprocessor(x_train, y_train)
embeddings = load_word_embeddings(p.vocab_word, embedding_path, model_config.word_embedding_size)
@@ -104,9 +105,14 @@ Now we are ready for training :)
Let's train a model. For training, we can use ***Trainer***.
Trainer manages everything about training.
Prepare an instance of the Trainer class and pass the training and validation data to its train method:
```
trainer = anago.Trainer(model_config, training_config, checkpoint_path=LOG_ROOT, save_path=SAVE_ROOT,
preprocessor=p, embeddings=embeddings)
```python
model = SeqLabeling(model_config, embeddings, len(p.vocab_tag))
trainer = anago.Trainer(model,
training_config,
checkpoint_path=LOG_ROOT,
save_path=SAVE_ROOT,
preprocessor=p,
embeddings=embeddings)
trainer.train(x_train, y_train, x_valid, y_valid)
```
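Note that with this commit the Trainer no longer saves the model itself (the save step was removed from `anago/trainer.py`), so the trained weights have to be saved explicitly. A minimal sketch, following the updated test and assuming `os` is imported and `SAVE_ROOT` is defined as above:

```python
model.save(os.path.join(SAVE_ROOT, 'model_weights.h5'))   # save the trained weights
p.save(os.path.join(SAVE_ROOT, 'preprocessor.pkl'))       # save the fitted preprocessor
```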

3 changes: 0 additions & 3 deletions anago/config.py
@@ -8,9 +8,6 @@ def __init__(self):
self.vocab_size = None
self.char_vocab_size = None

# Batch size.
self.batch_size = 32

# Scale used to initialize model variables.
self.initializer_scale = 0.08

7 changes: 5 additions & 2 deletions anago/evaluator.py
@@ -21,8 +21,11 @@ def __init__(self,
def eval(self, x_test, y_test):

# Prepare test data(steps, generator)
train_steps, train_batches = batch_iter(
list(zip(x_test, y_test)), self.config.batch_size, preprocessor=self.preprocessor)
train_steps, train_batches = batch_iter(x_test,
y_test,
self.config.batch_size,
shuffle=False,
preprocessor=self.preprocessor)

# Build the model
model = SeqLabeling(self.config, ntags=len(self.preprocessor.vocab_tag))
69 changes: 0 additions & 69 deletions anago/layers.py
@@ -363,72 +363,3 @@ def sparse_loss(*args):
return method(*args)

return {'ChainCRF': ClassWrapper, 'loss': loss, 'sparse_loss': sparse_loss}


class CRFLayer(Layer):

def __init__(self, **kwargs):
super(CRFLayer, self).__init__(**kwargs)
self.input_spec = [InputSpec(ndim=3)]

def compute_output_shape(self, input_shape):
assert input_shape and len(input_shape) == 3
return (input_shape[0], input_shape[1], input_shape[2])

def _fetch_mask(self):
mask = None
if self.inbound_nodes:
mask = self.inbound_nodes[0].input_masks[0]
return mask

def build(self, input_shape):
assert len(input_shape) == 3
n_classes = input_shape[2]
n_steps = input_shape[1]
assert n_steps is None or n_steps >= 2
self.transition_params = self.add_weight((n_classes, n_classes),
initializer='uniform',
name='transition')
self.input_spec = [InputSpec(dtype=K.floatx(),
shape=(None, n_steps, n_classes))]
self.built = True

def viterbi_decode(self, x, mask):
viterbi_sequences = []
transition_params = K.eval(self.transition_params)
#logits = tf.map_fn(lambda x: x[0][:x[1]], [x, mask])
#print(logits)
logits = x
sequences = tf.map_fn(lambda p: tf.contrib.crf.viterbi_decode(p, transition_params)[0], logits)
print(sequences)
for logit, sequence_length in zip(x, mask):
logit = logit[:sequence_length]
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(logit, transition_params)
viterbi_sequences += [viterbi_sequence]

return viterbi_sequences

def call(self, x, mask=[2,2]):
y_pred = self.viterbi_decode(x, mask)
nb_classes = self.input_spec[0].shape[2]
y_pred_one_hot = K.one_hot(y_pred, nb_classes)
return K.in_train_phase(x, y_pred_one_hot)

def loss(self, y_true, y_pred):
mask = self._fetch_mask()
#sequence_lengths = K.reshape(mask, (-1,))
sequence_lengths = mask
y_t = K.argmax(y_true, -1)
y_t = K.cast(y_t, tf.int32)
log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
y_pred, y_t, sequence_lengths, self.transition_params)
loss = tf.reduce_mean(-log_likelihood)

return loss

def get_config(self):
config = {
'transition_params': initializers.serialize(self.transition_params),
}
base_config = super(CRFLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
1 change: 0 additions & 1 deletion anago/models.py
@@ -1,4 +1,3 @@
import tensorflow as tf
import keras.backend as K
from keras.layers import Dense, LSTM, Bidirectional, Embedding, Input, Dropout, Lambda
from keras.layers.merge import Concatenate
26 changes: 10 additions & 16 deletions anago/trainer.py
@@ -4,13 +4,12 @@

from anago.data.metrics import get_callbacks
from anago.data.reader import batch_iter
from anago.models import SeqLabeling


class Trainer(object):

def __init__(self,
model_config,
model,
training_config,
checkpoint_path='',
save_path='',
@@ -19,7 +18,7 @@ def __init__(self,
embeddings=None
):

self.model_config = model_config
self.model = model
self.training_config = training_config
self.checkpoint_path = checkpoint_path
self.save_path = save_path
@@ -35,23 +34,18 @@ def train(self, x_train, y_train, x_valid=None, y_valid=None):
valid_steps, valid_batches = batch_iter(
x_valid, y_valid, self.training_config.batch_size, preprocessor=self.preprocessor)

# Build the model
model = SeqLabeling(self.model_config, self.embeddings, len(self.preprocessor.vocab_tag))
model.compile(loss=model.crf.loss,
optimizer=Adam(lr=self.training_config.learning_rate),
)
self.model.compile(loss=self.model.crf.loss,
optimizer=Adam(lr=self.training_config.learning_rate),
)

# Prepare callbacks for training
# Prepare callbacks
callbacks = get_callbacks(log_dir=self.checkpoint_path,
tensorboard=self.tensorboard,
eary_stopping=self.training_config.early_stopping,
valid=(valid_steps, valid_batches, self.preprocessor))

# Train the model
model.fit_generator(generator=train_batches,
steps_per_epoch=train_steps,
epochs=self.training_config.max_epoch,
callbacks=callbacks)

# Save the model
model.save(os.path.join(self.save_path, 'model_weights.h5'))
self.model.fit_generator(generator=train_batches,
steps_per_epoch=train_steps,
epochs=self.training_config.max_epoch,
callbacks=callbacks)
29 changes: 20 additions & 9 deletions requirements.txt
@@ -1,17 +1,28 @@
backports.weakref==1.0rc1
bleach==1.5.0
h5py==2.7.0
boto==2.48.0
bz2file==0.98
certifi==2017.11.5
chardet==3.0.4
enum34==1.1.6
gensim==3.1.0
h5py==2.7.1
html5lib==0.9999999
Keras==2.0.5
Markdown==2.2.0
numpy==1.13.0
protobuf==3.3.0
idna==2.6
Keras==2.1.1
Markdown==2.6.9
numpy==1.13.3
protobuf==3.5.0
python-dateutil==2.6.0
pytz==2017.2
PyYAML==3.12
scikit-learn==0.18.2
scipy==0.19.1
six==1.10.0
tensorflow==1.2.0
requests==2.18.4
scikit-learn==0.19.1
scipy==1.0.0
six==1.11.0
smart-open==1.5.3
tensorflow==1.4.0
tensorflow-tensorboard==0.4.0rc3
Theano==0.9.0
urllib3==1.22
Werkzeug==0.12.2
2 changes: 1 addition & 1 deletion setup.py
@@ -27,7 +27,7 @@
sys.exit()

required = [
'Keras>=2.0.5', 'h5py>=2.7.0', 'scikit-learn>0.18.2', 'numpy>=1.13.0', 'tensorflow>=1.2.0',
'Keras>=2.1.1', 'h5py>=2.7.1', 'scikit-learn>0.19.1', 'numpy>=1.13.3', 'tensorflow>=1.4.0',
]

setup(
43 changes: 29 additions & 14 deletions tests/train_test.py
@@ -1,37 +1,52 @@
import os
import unittest

import numpy as np

import anago
from anago.data.reader import load_data_and_labels, load_word_embeddings
from anago.data.preprocess import prepare_preprocessor
from anago.config import ModelConfig, TrainingConfig
from anago.models import SeqLabeling


get_path = lambda path: os.path.join(os.path.dirname(__file__), path)
DATA_ROOT = get_path('../data/conll2003/en/ner')
SAVE_ROOT = get_path('models') # trained model
LOG_ROOT = get_path('logs') # checkpoint, tensorboard
EMBEDDING_PATH = get_path('../data/glove.6B/glove.6B.100d.txt')


class TrainerTest(unittest.TestCase):

def test_train(self):
DATA_ROOT = os.path.join(os.path.dirname(__file__), '../data/conll2003/en/ner')
SAVE_ROOT = os.path.join(os.path.dirname(__file__), '../models') # trained model
LOG_ROOT = os.path.join(os.path.dirname(__file__), '../logs') # checkpoint, tensorboard
embedding_path = os.path.join(os.path.dirname(__file__), '../data/glove.6B/glove.6B.100d.txt')
@classmethod
def setUpClass(cls):
if not os.path.exists(LOG_ROOT):
os.mkdir(LOG_ROOT)

if not os.path.exists(SAVE_ROOT):
os.mkdir(SAVE_ROOT)

def test_train(self):
model_config = ModelConfig()
training_config = TrainingConfig()

train_path = os.path.join(DATA_ROOT, 'train.txt')
valid_path = os.path.join(DATA_ROOT, 'valid.txt')
test_path = os.path.join(DATA_ROOT, 'test.txt')
x_train, y_train = load_data_and_labels(train_path)
x_valid, y_valid = load_data_and_labels(valid_path)
x_test, y_test = load_data_and_labels(test_path)

p = prepare_preprocessor(np.r_[x_train, x_valid, x_test], y_train) # np.r_ is for vocabulary expansion.
p = prepare_preprocessor(x_train, y_train)
p.save(os.path.join(SAVE_ROOT, 'preprocessor.pkl'))
embeddings = load_word_embeddings(p.vocab_word, embedding_path, model_config.word_embedding_size)
embeddings = load_word_embeddings(p.vocab_word, EMBEDDING_PATH, model_config.word_embedding_size)
model_config.char_vocab_size = len(p.vocab_char)

trainer = anago.Trainer(model_config, training_config, checkpoint_path=LOG_ROOT, save_path=SAVE_ROOT,
preprocessor=p, embeddings=embeddings)
trainer.train(x_train, y_train, x_test, y_test)
model = SeqLabeling(model_config, embeddings, len(p.vocab_tag))

trainer = anago.Trainer(model,
training_config,
checkpoint_path=LOG_ROOT,
save_path=SAVE_ROOT,
preprocessor=p,
embeddings=embeddings)
trainer.train(x_train, y_train, x_valid, y_valid)

model.save(os.path.join(SAVE_ROOT, 'model_weights.h5'))
