
Commit 990fe86

Update model
Hironsan committed Mar 2, 2018
1 parent e591b85 commit 990fe86
Showing 1 changed file with 14 additions and 51 deletions.
anago/models.py: 65 changes (14 additions & 51 deletions)
@@ -2,29 +2,28 @@
 from keras.layers import Dense, LSTM, Bidirectional, Embedding, Input, Dropout, Lambda, Activation
 from keras.layers.merge import Concatenate
 from keras.models import Model
-from sklearn.model_selection import train_test_split

 from anago.layers import ChainCRF
-from anago.reader import batch_iter
-from anago.callbacks import get_callbacks


 class BaseModel(object):

     def __init__(self):
         self.model = None

     def predict(self, X, *args, **kwargs):
         y_pred = self.model.predict(X, batch_size=1)
         return y_pred

     def score(self, X, y):
         score = self.model.evaluate(X, y, batch_size=1)
         return score

     def save(self, filepath):
         self.model.save_weights(filepath)

     def load(self, filepath):
         self.model.load_weights(filepath=filepath)

     def __getattr__(self, name):
         return getattr(self.model, name)


 class BiLSTMCRF(BaseModel):
     """A Keras implementation of BiLSTM-CRF for sequence labeling.
@@ -38,8 +37,7 @@ class BiLSTMCRF(BaseModel):

     def __init__(self, char_emb_size=25, word_emb_size=100, char_lstm_units=25,
                  word_lstm_units=100, dropout=0.5, char_feature=True, use_crf=True,
-                 word_vocab_size=10000, char_vocab_size=100, embeddings=None, ntags=None,
-                 batch_size=32, optimizer='adam', max_epoch=15, early_stopping=False):
+                 word_vocab_size=10000, char_vocab_size=100, embeddings=None, ntags=None):
         self._char_emb_size = char_emb_size
         self._word_emb_size = word_emb_size
         self._char_lstm_units = char_lstm_units
@@ -51,12 +49,9 @@ def __init__(self, char_emb_size=25, word_emb_size=100, char_lstm_units=25,
         self._use_crf = use_crf
         self._embeddings = embeddings
         self._ntags = ntags
-        self._batch_size = batch_size
-        self._optimizer = optimizer
-        self._max_epoch = max_epoch
-        self._early_stopping = early_stopping
+        self._loss = None

-    def _build_model(self):
+    def build_model(self):
         # build word embedding
         word_ids = Input(batch_shape=(None, None), dtype='int32')
         if self._embeddings is None:
@@ -97,8 +92,9 @@ def _build_model(self):
         x = Dense(self._ntags)(x)

         if self._use_crf:
-            self.crf = ChainCRF()
-            pred = self.crf(x)
+            crf = ChainCRF()
+            self._loss = crf.loss
+            pred = crf(x)
         else:
             pred = Activation('softmax')(x)

@@ -108,38 +104,5 @@
         else:
             self.model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred])

-    def fit(self, X, y):
-        x_train, x_valid, y_train, y_valid = train_test_split(X, y, test_size=0.3, random_state=42)
-        # Prepare training and validation data(steps, generator)
-        train_steps, train_batches = batch_iter(x_train,
-                                                y_train,
-                                                self._batch_size,
-                                                preprocessor=self.preprocessor)
-        valid_steps, valid_batches = batch_iter(x_valid,
-                                                y_valid,
-                                                self._batch_size,
-                                                preprocessor=self.preprocessor)
-
-        self._build_model()
-
-        if self._use_crf:
-            self.model.compile(loss=self.crf.loss,
-                               optimizer=self._optimizer)
-        else:
-            self.model.compile(loss='categorical_crossentropy',
-                               optimizer=self._optimizer)
-
-        # Prepare callbacks
-        """
-        callbacks = get_callbacks(log_dir=self.checkpoint_path,
-                                  tensorboard=self.tensorboard,
-                                  eary_stopping=self._early_stopping,
-                                  valid=(valid_steps, valid_batches, self.preprocessor))
-        """
-        callbacks = []
-
-        # Train the model
-        self.model.fit_generator(generator=train_batches,
-                                 steps_per_epoch=train_steps,
-                                 epochs=self._max_epoch,
-                                 callbacks=callbacks)
+    def get_loss(self):
+        return self._loss

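After this commit the model no longer trains itself: fit() and its data pipeline are deleted, the private _build_model() becomes the public build_model(), and the ChainCRF loss is captured in self._loss and exposed via get_loss(), so compilation and training move outside the class. Below is a minimal sketch of the resulting workflow; the vocabulary sizes, tag count, optimizer, and training call are illustrative assumptions, not part of the commit.

    from anago.models import BiLSTMCRF

    # Build the network. build_model() populates self.model and, when
    # use_crf=True, records the ChainCRF loss in self._loss.
    clf = BiLSTMCRF(word_vocab_size=10000, char_vocab_size=100, ntags=10)
    clf.build_model()

    # get_loss() returns the CRF loss, or None when use_crf=False (the
    # softmax output then pairs with 'categorical_crossentropy').
    loss = clf.get_loss() or 'categorical_crossentropy'

    # BaseModel.__getattr__ forwards unknown attribute lookups to the
    # wrapped Keras model, so compile() and fit() reach self.model.
    clf.compile(loss=loss, optimizer='adam')
    # clf.fit(...)  # batching, callbacks, and validation are now the caller's job

This decoupling lets a separate trainer own batching, callbacks, and validation splits, all of which the deleted fit() had hard-coded.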