Update model
Hironsan committed Jun 19, 2018
1 parent e2b03db commit 07a6cab
Showing 6 changed files with 44 additions and 52 deletions.
54 changes: 23 additions & 31 deletions anago/models.py
@@ -10,30 +10,22 @@
 from anago.layers import CRF
 
 
-class BaseModel(object):
+def save_model(model, weights_file, params_file):
+    with open(params_file, 'w') as f:
+        params = model.to_json()
+        json.dump(json.loads(params), f, sort_keys=True, indent=4)
+    model.save_weights(weights_file)
 
-    def __init__(self):
-        self.model = None
 
-    def save(self, weights_file, params_file):
-        with open(params_file, 'w') as f:
-            params = self.model.to_json()
-            json.dump(json.loads(params), f, sort_keys=True, indent=4)
-        self.model.save_weights(weights_file)
+def load_model(weights_file, params_file):
+    with open(params_file) as f:
+        model = model_from_json(f.read(), custom_objects={'CRF': CRF})
+        model.load_weights(weights_file)
 
-    @classmethod
-    def load(cls, weights_file, params_file):
-        with open(params_file) as f:
-            model = model_from_json(f.read(), custom_objects={'CRF': CRF})
-            model.load_weights(weights_file)
-
-        return model
-
-    def __getattr__(self, name):
-        return getattr(self.model, name)
+    return model
 
 
-class BiLSTMCRF(BaseModel):
+class BiLSTMCRF(object):
     """A Keras implementation of BiLSTM-CRF for sequence labeling.
 
     References
@@ -85,30 +77,31 @@ def __init__(self,
         self._use_crf = use_crf
         self._embeddings = embeddings
         self._num_labels = num_labels
-        self._loss = None
 
     def build(self):
         # build word embedding
-        word_ids = Input(batch_shape=(None, None), dtype='int32')
+        word_ids = Input(batch_shape=(None, None), dtype='int32', name='word_input')
         inputs = [word_ids]
         if self._embeddings is None:
             word_embeddings = Embedding(input_dim=self._word_vocab_size,
                                         output_dim=self._word_embedding_dim,
-                                        mask_zero=True)(word_ids)
+                                        mask_zero=True,
+                                        name='word_embedding')(word_ids)
         else:
             word_embeddings = Embedding(input_dim=self._embeddings.shape[0],
                                         output_dim=self._embeddings.shape[1],
                                         mask_zero=True,
-                                        weights=[self._embeddings])(word_ids)
+                                        weights=[self._embeddings],
+                                        name='word_embedding')(word_ids)
 
         # build character based word embedding
         if self._use_char:
-            char_ids = Input(batch_shape=(None, None, None), dtype='int32')
+            char_ids = Input(batch_shape=(None, None, None), dtype='int32', name='char_input')
             inputs.append(char_ids)
             char_embeddings = Embedding(input_dim=self._char_vocab_size,
                                         output_dim=self._char_embedding_dim,
-                                        mask_zero=True
-                                        )(char_ids)
+                                        mask_zero=True,
+                                        name='char_embedding')(char_ids)
             char_embeddings = TimeDistributed(Bidirectional(LSTM(self._char_lstm_size)))(char_embeddings)
             word_embeddings = Concatenate()([word_embeddings, char_embeddings])
 
@@ -118,13 +111,12 @@ def build(self):
 
         if self._use_crf:
             crf = CRF(self._num_labels, sparse_target=False)
-            self._loss = crf.loss_function
+            loss = crf.loss_function
             pred = crf(z)
         else:
-            self._loss = 'categorical_crossentropy'
+            loss = 'categorical_crossentropy'
             pred = Dense(self._num_labels, activation='softmax')(z)
 
-        self.model = Model(inputs=inputs, outputs=pred)
+        model = Model(inputs=inputs, outputs=pred)
 
-    def get_loss(self):
-        return self._loss
+        return model, loss
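
Taken together, the models.py changes drop the BaseModel plumbing in favour of a plain functional API: BiLSTMCRF.build() now returns the Keras model together with the loss to compile it with, and persistence goes through the module-level save_model/load_model helpers. A minimal sketch of the new flow (the vocabulary sizes and file names below are illustrative placeholders, not values from this commit):

    from anago.models import BiLSTMCRF, save_model, load_model

    # build() now returns the Keras model and the loss to compile it with.
    model, loss = BiLSTMCRF(char_vocab_size=100,
                            word_vocab_size=10000,
                            num_labels=9).build()
    model.compile(loss=loss, optimizer='adam')

    # ... train the compiled model, e.g. via anago.trainer.Trainer ...

    # Persist the architecture (JSON) and weights (HDF5), then restore both.
    save_model(model, 'weights.h5', 'params.json')
    model = load_model('weights.h5', 'params.json')
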
10 changes: 5 additions & 5 deletions anago/wrapper.py
@@ -3,7 +3,7 @@
 """
 from seqeval.metrics import f1_score
 
-from anago.models import BiLSTMCRF
+from anago.models import BiLSTMCRF, save_model, load_model
 from anago.preprocessing import IndexTransformer
 from anago.tagger import Tagger
 from anago.trainer import Trainer
@@ -77,8 +77,8 @@ def fit(self, x_train, y_train, x_valid=None, y_valid=None,
                           embeddings=embeddings,
                           use_char=self.use_char,
                           use_crf=self.use_crf)
-        model.build()
-        model.compile(loss=model.get_loss(), optimizer=self.optimizer)
+        model, loss = model.build()
+        model.compile(loss=loss, optimizer=self.optimizer)
 
         trainer = Trainer(model, preprocessor=p)
         trainer.train(x_train, y_train, x_valid, y_valid,
@@ -131,12 +131,12 @@ def analyze(self, text, tokenizer=str.split):
 
     def save(self, weights_file, params_file, preprocessor_file):
         self.p.save(preprocessor_file)
-        self.model.save(weights_file, params_file)
+        save_model(self.model, weights_file, params_file)
 
     @classmethod
     def load(cls, weights_file, params_file, preprocessor_file):
         self = cls()
         self.p = IndexTransformer.load(preprocessor_file)
-        self.model = BiLSTMCRF.load(weights_file, params_file)
+        self.model = load_model(weights_file, params_file)
 
         return self
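
At the wrapper level the public interface is unchanged: Sequence.save() and Sequence.load() still take the same three file paths and now simply delegate to save_model/load_model internally. A minimal usage sketch (x_train/y_train and the file names are illustrative placeholders):

    import anago

    model = anago.Sequence()
    model.fit(x_train, y_train)  # lists of token sequences and tag sequences
    model.save('weights.h5', 'params.json', 'preprocessor.pickle')

    model = anago.Sequence.load('weights.h5', 'params.json', 'preprocessor.pickle')
    print(model.analyze('President Obama is speaking at the White House.'))
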
8 changes: 4 additions & 4 deletions tests/test_model.py
@@ -2,7 +2,7 @@
 import shutil
 import unittest
 
-from anago.models import BiLSTMCRF
+from anago.models import BiLSTMCRF, load_model, save_model
 
 
 class TestModel(unittest.TestCase):
@@ -56,14 +56,14 @@ def test_save_and_load(self):
         model = BiLSTMCRF(char_vocab_size=char_vocab_size,
                           word_vocab_size=word_vocab_size,
                           num_labels=num_labels)
-        model.build()
+        model, loss = model.build()
 
         self.assertFalse(os.path.exists(self.weights_file))
         self.assertFalse(os.path.exists(self.params_file))
 
-        model.save(self.weights_file, self.params_file)
+        save_model(model, self.weights_file, self.params_file)
 
         self.assertTrue(os.path.exists(self.weights_file))
         self.assertTrue(os.path.exists(self.params_file))
 
-        model = BiLSTMCRF.load(self.weights_file, self.params_file)
+        model = load_model(self.weights_file, self.params_file)
4 changes: 2 additions & 2 deletions tests/test_tagger.py
@@ -5,7 +5,7 @@
 import tensorflow as tf
 
 import anago
-from anago.models import BiLSTMCRF
+from anago.models import load_model
 from anago.preprocessing import IndexTransformer
 
 DATA_ROOT = os.path.join(os.path.dirname(__file__), '../data/conll2003/en/ner')
@@ -24,7 +24,7 @@ def setUpClass(cls):
         p = IndexTransformer.load(preprocessor_file)
 
         # Load the model.
-        model = BiLSTMCRF.load(weights_file, params_file)
+        model = load_model(weights_file, params_file)
 
         # Build a tagger
         cls.tagger = anago.Tagger(model, preprocessor=p)
16 changes: 8 additions & 8 deletions tests/test_trainer.py
@@ -2,7 +2,7 @@
 import unittest
 
 from anago.utils import load_data_and_labels
-from anago.models import BiLSTMCRF
+from anago.models import BiLSTMCRF, save_model
 from anago.preprocessing import IndexTransformer
 from anago.trainer import Trainer
 
@@ -42,8 +42,8 @@ def setUp(self):
         self.model = BiLSTMCRF(char_vocab_size=self.p.char_vocab_size,
                                word_vocab_size=self.p.word_vocab_size,
                                num_labels=self.p.label_size)
-        self.model.build()
-        self.model.compile(loss=self.model.get_loss(), optimizer='adam')
+        self.model, loss = self.model.build()
+        self.model.compile(loss=loss, optimizer='adam')
 
     def test_train(self):
         trainer = Trainer(self.model, preprocessor=self.p)
@@ -59,8 +59,8 @@ def test_train_no_crf(self):
                           word_vocab_size=self.p.word_vocab_size,
                           num_labels=self.p.label_size,
                           use_crf=False)
-        model.build()
-        model.compile(loss=model.get_loss(), optimizer='adam')
+        model, loss = model.build()
+        model.compile(loss=loss, optimizer='adam')
         trainer = Trainer(model, preprocessor=self.p)
         trainer.train(self.x_train, self.y_train,
                       x_valid=self.x_valid, y_valid=self.y_valid)
@@ -72,8 +72,8 @@ def test_train_no_character(self):
                           num_labels=p.label_size,
                           use_crf=False,
                           use_char=False)
-        model.build()
-        model.compile(loss=model.get_loss(), optimizer='adam')
+        model, loss = model.build()
+        model.compile(loss=loss, optimizer='adam')
         trainer = Trainer(model, preprocessor=p)
         trainer.train(self.x_train, self.y_train,
                       x_valid=self.x_valid, y_valid=self.y_valid)
@@ -84,5 +84,5 @@ def test_save(self):
         trainer.train(self.x_train, self.y_train)
 
         # Save the model.
-        self.model.save(self.weights_file, self.params_file)
+        save_model(self.model, self.weights_file, self.params_file)
         self.p.save(self.preprocessor_file)
4 changes: 2 additions & 2 deletions tests/test_wrapper.py
@@ -103,7 +103,7 @@ def test_train_callbacks(self):
         preprocessor_file = os.path.join(SAVE_ROOT, 'preprocessor.pickle')
 
         log_dir = os.path.join(os.path.dirname(__file__), 'logs')
-        file_name = '_'.join(['model_weights', '{epoch:02d}', '{f1:2.4f}']) + '.h5'
+        file_name = '_'.join(['weights', '{epoch:02d}', '{f1:2.4f}']) + '.h5'
         callback = ModelCheckpoint(os.path.join(log_dir, file_name),
                                    monitor='f1',
                                    save_weights_only=True)
@@ -113,5 +113,5 @@
                 vocab.add(word)
         model = anago.Sequence(initial_vocab=vocab, embeddings=self.embeddings)
         model.fit(self.x_train, self.y_train, self.x_test, self.y_test,
-                  epochs=30, callbacks=[callback])
+                  epochs=1, callbacks=[callback])
         model.save(weights_file, params_file, preprocessor_file)
