In [1]:
import numpy as np
from sklearn.model_selection import ShuffleSplit
from data_utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
from data_utils import Evaluator
from models import build_lstm_crf_model
from gensim.models import Word2Vec

Using TensorFlow backend.


In [2]:
data_dir = 'brat/'
ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))
idx2ent = dict([(v, k) for k, v in ent2idx.items()])

In [3]:
docs = Documents(data_dir=data_dir)
rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2018)
train_doc_ids, test_doc_ids = next(rs.split(docs))
train_docs, test_docs = docs[train_doc_ids], docs[test_doc_ids]

In [12]:
test_doc_ids

array([59,  7, 85, 41, 94, 33, 17, 56, 52,  2, 77, 37,  5, 62, 73, 95, 51,
       96, 66, 43])

In [13]:
train_doc_ids

array([93, 23, 11, 65, 67, 18, 32, 30, 74, 14, 48, 57, 80, 15, 27,  8, 24,
       46, 55, 39, 61, 89, 86, 54, 50, 12,  1, 19,  4, 16,  3, 34, 10, 72,
       35, 82, 36, 58, 76, 49, 64, 68, 44, 63, 13, 70, 45, 78, 81, 83, 69,
       84, 29, 79, 71, 40, 53, 38, 26,  0, 42, 92, 22, 31, 60, 90, 88, 47,
       75, 87, 91, 25,  6, 20, 28, 21,  9])

In [4]:
num_cates = max(ent2idx.values()) + 1
sent_len = 64
vocab_size = 3000
emb_size = 100
sent_pad = 10
sent_extrator = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)
train_sents = sent_extrator(train_docs)
test_sents = sent_extrator(test_docs)
train_data = Dataset(train_sents, cate2idx=ent2idx)
train_data.build_vocab_dict(vocab_size=vocab_size)
test_data = Dataset(test_sents, word2idx=train_data.word2idx, cate2idx=ent2idx)
vocab_size = len(train_data.word2idx)

In [5]:
w2v_train_sents = []
for doc in docs:
    w2v_train_sents.append(list(doc.text))
    
    
w2v_model = Word2Vec(w2v_train_sents, size=emb_size)
w2v_embeddings = np.zeros((vocab_size, emb_size))
for char, char_idx in train_data.word2idx.items():
    if char in w2v_model.wv:
        w2v_embeddings[char_idx] = w2v_model.wv[char]

In [6]:
seq_len = sent_len + 2 * sent_pad
model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=vocab_size, 
                             model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
model.summary()

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 84)                0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 84, 100)           15100     
_________________________________________________________________
bidirectional_1 (Bidirection (None, 84, 512)           731136    
_________________________________________________________________
crf_1 (CRF)                  (None, 84, 40)            22200     
Total params: 768,436
Trainable params: 753,336
Non-trainable params: 15,100
_________________________________________________________________


In [7]:
train_X, train_y = train_data[:]
print('train_X.shape', train_X.shape)
print('train_y.shape', train_y.shape)

train_X.shape (50146, 84)
train_y.shape (50146, 84, 1)


In [14]:
train_X

array([[ 1,  1,  1, ..., 41,  1,  4],
       [ 9,  2, 20, ..., 14,  7, 24],
       [ 7,  8,  7, ...,  1,  6,  1],
       ...,
       [10, 19,  1, ...,  5,  7,  9],
       [ 1, 52, 15, ...,  5,  4, 55],
       [ 1, 16,  9, ...,  1,  1,  1]])

In [9]:
model.fit(train_X,train_y, batch_size=64, epochs=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.callbacks.History at 0x7f1b6042f190>

In [10]:
test_X, _ = test_data[:]
preds = model.predict(test_X, batch_size=64, verbose=True)
pred_docs = make_predictions(preds, test_data, sent_pad, docs, idx2ent)



In [11]:
f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
print('f_score: ', f_score)
print('precision: ', precision)
print('recall: ', recall)

f_score:  0.8135261802042272
precision:  0.899577856333871
recall:  0.7425002124585706


In [12]:
sample_doc_id = list(pred_docs.keys())[0]
test_docs[sample_doc_id]

<data_utils.data_utils.Document at 0x7f1ad42ff190>

In [13]:
pred_docs[sample_doc_id]

<data_utils.data_utils.Document at 0x7f1ad40d9750>

In [None]:
ex = {'text': "123", 'ents': "321", 'title': None}