In [21]:
from keras.models import Sequential, Input
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional
from keras_contrib.layers.crf import CRF
from keras_contrib.utils import save_load_utils
from keras_contrib.metrics import crf_accuracy
from keras_contrib.losses import crf_loss

In [22]:
class LSTMmodel:
    """Bidirectional-LSTM sequence tagger topped with a CRF layer.

    The network is a Keras ``Sequential`` stack of
    ``Bidirectional(LSTM) -> Dropout`` (repeated ``num_layers`` times),
    followed by ``TimeDistributed(Dense(num_tags))`` and a ``CRF(num_tags)``
    layer. It is compiled with the ``rmsprop`` optimizer, ``crf_loss`` and
    the ``crf_accuracy`` metric.

    Attributes:
        num_tags: Number of output tags used for the Dense and CRF layers.
        model: The compiled Keras ``Sequential`` model.
    """

    def __init__(self, input_length, para_emb_dim, num_tags, hidden_dim=200, dropout=0.5, num_layers=1):
        """Build and compile the tagger.

        Args:
            input_length: Number of time steps per input sequence.
            para_emb_dim: Size of the feature vector at each time step.
            num_tags: Number of tags scored at each time step.
            hidden_dim: Units of each LSTM direction.
            dropout: Rate of the Dropout layer that follows every BiLSTM.
            num_layers: How many BiLSTM + Dropout pairs to stack (default 1,
                which reproduces the original single-layer architecture).

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1, got {}".format(num_layers))

        self.num_tags = num_tags
        self.model = Sequential()
        for layer_idx in range(num_layers):
            recurrent = LSTM(hidden_dim, return_sequences=True)
            if layer_idx == 0:
                # Only the first layer of a Sequential model declares the input shape.
                self.model.add(Bidirectional(recurrent, input_shape=(input_length, para_emb_dim)))
            else:
                self.model.add(Bidirectional(recurrent))
            self.model.add(Dropout(dropout))
        self.model.add(TimeDistributed(Dense(self.num_tags)))
        crf = CRF(self.num_tags)
        self.model.add(crf)
        # Use the directly imported `crf_loss`; the notebook never imports a
        # `losses` module, so `losses.crf_loss` fails on a fresh kernel.
        self.model.compile('rmsprop', loss=crf_loss, metrics=[crf_accuracy])

    def save_model(self, filepath):
        """Write the model's weights to ``filepath`` via ``save_load_utils.save_all_weights``."""
        save_load_utils.save_all_weights(self.model, filepath)

    def restore_model(self, filepath):
        """Load weights from ``filepath`` into the model via ``save_load_utils.load_all_weights``."""
        save_load_utils.load_all_weights(self.model, filepath)

    def train(self, trainX, trainY, batch_size=32, epochs=10, validation_split=0.1, verbose=1):
        """Fit the model and return whatever ``model.fit`` returns.

        Args:
            trainX: Training inputs, passed to ``fit`` unchanged.
            trainY: Training targets; converted with ``np.array`` before fitting.
                # NOTE(review): expected label encoding is not visible here - confirm with caller.
            batch_size: Forwarded to ``fit``.
            epochs: Forwarded to ``fit``.
            validation_split: Forwarded to ``fit``.
            verbose: Forwarded to ``fit``.
        """
        return self.model.fit(trainX, np.array(trainY), batch_size=batch_size, epochs=epochs,
                              validation_split=validation_split, verbose=verbose)
        

In [23]:
# Instantiate the tagger: 300 time steps, 300-dim features per step, 4 tags.
model = LSTMmodel(input_length=300, para_emb_dim=300, num_tags=4)

In [24]:
# Print the layer-by-layer architecture (output shapes and parameter counts)
# of the underlying Keras model.
model.model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bidirectional_6 (Bidirection (None, 300, 400)          801600    
_________________________________________________________________
dropout_6 (Dropout)          (None, 300, 400)          0         
_________________________________________________________________
time_distributed_6 (TimeDist (None, 300, 4)            1604      
_________________________________________________________________
crf_6 (CRF)                  (None, 300, 4)            44        
Total params: 803,248
Trainable params: 803,248
Non-trainable params: 0
_________________________________________________________________
