In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional
from keras.models import Sequential, Input
from keras.utils import to_categorical
from keras_contrib.layers.crf import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_accuracy
from keras_contrib.utils import save_load_utils

Using TensorFlow backend.


In [2]:
import os

# Location of the BERT checkpoint directory, resolved against the directory the
# kernel is currently running in (so the notebook must be started from the
# project root for this path to exist).
working_dir = os.getcwd()
BERT_BASE = os.path.join(working_dir, 'bert/bert_model/uncased_L-12_H-768_A-12')

In [3]:
class LSTMmodel:
    """Bidirectional-LSTM + CRF sequence tagger wrapped around a Keras ``Sequential``.

    Layer stack, in order: Bidirectional LSTM -> Dropout -> per-timestep Dense
    projection to ``num_tags`` -> CRF. The model is compiled inside
    ``__init__`` with the ``'rmsprop'`` optimizer, ``crf_loss`` as the loss and
    ``crf_accuracy`` as the only metric. The underlying Keras model is exposed
    as ``self.model``.
    """

    def __init__(self, input_length, para_emb_dim, num_tags, hidden_dim=200, dropout=0.5):
        """Build and compile the network.

        Args:
            input_length: number of timesteps in every input sequence.
            para_emb_dim: size of the feature vector at each timestep.
            num_tags: number of output tags (width of the Dense and CRF layers).
            hidden_dim: LSTM units per direction.
            dropout: rate given to the Dropout layer that follows the LSTM.
        """
        self.num_tags = num_tags
        self.model = Sequential()
        # return_sequences=True keeps one output per timestep, which the
        # per-timestep Dense and the CRF layer below both require.
        self.model.add(
            Bidirectional(
                LSTM(hidden_dim, return_sequences=True),
                input_shape=(input_length, para_emb_dim),
            )
        )
        self.model.add(Dropout(dropout))
        self.model.add(TimeDistributed(Dense(self.num_tags)))
        self.model.add(CRF(self.num_tags))
        self.model.compile('rmsprop', loss=crf_loss, metrics=[crf_accuracy])

    def save_model(self, filepath):
        """Write the model's weights to ``filepath`` via ``save_load_utils.save_all_weights``."""
        save_load_utils.save_all_weights(self.model, filepath)

    def restore_model(self, filepath):
        """Load weights from ``filepath`` into the already-built ``self.model``.

        The instance must have been constructed first, since the weights are
        loaded into the existing model object.
        """
        save_load_utils.load_all_weights(self.model, filepath)

    def train(self, trainX, trainY, batch_size=32, epochs=10, validation_split=0.1, verbose=1):
        """Fit the model and return whatever ``self.model.fit`` returns.

        Args:
            trainX: input sequences; any array-like (e.g. nested lists) is accepted.
            trainY: target sequences; any array-like is accepted.
            batch_size: forwarded to ``fit``.
            epochs: forwarded to ``fit``.
            validation_split: forwarded to ``fit``.
            verbose: forwarded to ``fit``.
        """
        # Convert both inputs and targets: the notebook builds them as Python
        # lists, and previously only the targets were turned into an array.
        inputs = np.asarray(trainX)
        targets = np.asarray(trainY)
        return self.model.fit(
            inputs,
            targets,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            verbose=verbose,
        )
        

In [4]:
from Dataprocessor import Dataprocessor

# Candidate inputs: data/0.json through data/499.json.
filelist = ['data/%d.json' % file_no for file_no in range(500)]

# load_data hands back three sequences, unpacked here in the order it returns them.
processor = Dataprocessor()
loaded = processor.load_data(filelist)
train_texts, train_tags, train_rawtags = loaded

loading 0.json...
loading 1.json...
loading 2.json...
loading 3.json...
loading 4.json...
loading 5.json...
loading 6.json...
loading 7.json...
loading 8.json...
loading 9.json...
loading 10.json...
loading 11.json...
loading 12.json...
loading 13.json...
loading 14.json...
loading 15.json...
loading 16.json...
loading 17.json...
loading 18.json...
loading 19.json...
loading 20.json...
loading 21.json...
loading 22.json...
loading 23.json...
loading 24.json...
loading 25.json...
loading 26.json...
loading 27.json...
loading 28.json...
loading 29.json...
loading 30.json...
loading 31.json...
loading 32.json...
loading 33.json...
loading 34.json...
loading 35.json...
loading 36.json...
loading 37.json...
loading 38.json...
loading 39.json...
loading 40.json...
loading 41.json...
loading 42.json...
loading 43.json...
loading 44.json...
loading 45.json...
loading 46.json...
loading 47.json...
loading 48.json...
loading 49.json...
loading 50.json...
loading 51.json...
loading 52.json...
loa

loading 416.json...
loading 417.json...
loading 418.json...
loading 419.json...
loading 420.json...
loading 421.json...
loading 422.json...
loading 423.json...
loading 424.json...
loading 425.json...
loading 426.json...
loading 427.json...
loading 428.json...
loading 429.json...
loading 430.json...
loading 431.json...
loading 432.json...
loading 433.json...
loading 434.json...
loading 435.json...
loading 436.json...
loading 437.json...
loading 438.json...
loading 439.json...
loading 440.json...
loading 441.json...
loading 442.json...
loading 443.json...
loading 444.json...
loading 445.json...
loading 446.json...
loading 447.json...
loading 448.json...
loading 449.json...
loading 450.json...
loading 451.json...
loading 452.json...
loading 453.json...
loading 454.json...
loading 455.json...
loading 456.json...
loading 457.json...
loading 458.json...
loading 459.json...
loading 460.json...
loading 461.json...
loading 462.json...
loading 463.json...
loading 464.json...
loading 465.json...


In [7]:
from bert_utils import get_all_features

# Files expected inside the BERT checkpoint directory.
bert_config_file, vocab_file, bert_checkpoint = (
    os.path.join(BERT_BASE, file_name)
    for file_name in ('bert_config.json', 'vocab.txt', 'bert_model.ckpt')
)

# Features for the first slice of articles (indices 0-1999).
# NOTE(review): the saved output of this cell is an
# "InternalError: Dst tensor is not initialized" raised on GPU:0 - presumably
# the GPU ran out of memory while the checkpoint was loading; confirm before
# relying on `feature`.
feature = get_all_features(train_texts[:2000], bert_config_file, vocab_file, bert_checkpoint)
print(len(feature))


Total 956 paragraphs


InternalError: Dst tensor is not initialized.
	 [[{{node checkpoint_initializer_162/_215}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_816_checkpoint_initializer_162", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

In [None]:
# Add the features for the next slice of articles (indices 2000-5999) to the
# ones already held in `feature`.
# NOTE(review): `+=` accumulates into the existing `feature`, so running this
# cell a second time adds the same slice again - re-run the cell that creates
# `feature` first if this one has to be repeated.
# NOTE(review): assumes get_all_features returns a list with one entry per
# input text - confirm against bert_utils.
feature += get_all_features(train_texts[2000:6000], bert_config_file, vocab_file, bert_checkpoint)
print(len(feature))

In [None]:
# Model hyper-parameters shared by the data-preparation and training cells.
INPUT_LENGTH = 100       # timesteps per article after padding / truncation
PARAGRAPH_EMB_DIM = 768  # feature size per timestep
NUM_TAGS = 12            # number of tag classes

model = LSTMmodel(
    input_length=INPUT_LENGTH,
    para_emb_dim=PARAGRAPH_EMB_DIM,
    num_tags=NUM_TAGS,
)
model.model.summary()

In [None]:
# load data
def _fit_to_length(seq, length, pad_item):
    """Return a NEW list of exactly `length` items.

    `seq` is truncated when too long and right-padded with `pad_item` when too
    short. The input sequence itself is never modified.
    """
    fitted = list(seq[:length])
    fitted.extend([pad_item] * (length - len(fitted)))
    return fitted


tags = train_tags[0:6000]

# X is 3D: article, paragraph, embedding; rawY is 2D: article, paragraph.
# Padding is applied to copies, so `feature` and `train_tags` keep their
# original, unpadded contents (the earlier in-place `append` loop wrote the
# padding back into both of them).
# Note: zip stops at the shorter of `feature` and `tags`.
zero_embedding = np.zeros(PARAGRAPH_EMB_DIM)
X, rawY = [], []
for article_features, article_tags in zip(feature, tags):
    X.append(_fit_to_length(article_features, INPUT_LENGTH, zero_embedding))
    # Padded positions get tag id 0.
    rawY.append(_fit_to_length(article_tags, INPUT_LENGTH, 0))

Y = [to_categorical(y, num_classes=NUM_TAGS) for y in rawY] # Y is now 3D

# Hold out the last 10% of the articles, in order (no shuffling).
data_size = len(X)
train_size = int(data_size * 0.9)
trainX, trainY = X[:train_size], Y[:train_size]
testX, testY = X[train_size:], Y[train_size:]

In [None]:
# train
# Stack the per-article lists into dense arrays before handing them to Keras.
train_inputs = np.array(trainX)
train_targets = np.array(trainY)
history = model.model.fit(
    train_inputs,
    train_targets,
    batch_size=32,
    epochs=10,
    validation_split=0.1,
)

In [9]:
# plot
# NOTE(review): the plotting code below is disabled. As written it reads
# `hist`, which no visible cell defines; the training cell binds the result of
# fit() to `history`. Presumably the intent was `history.history[...]` -
# confirm, and also confirm the metric key names (they may not be
# "acc"/"val_acc" when the metric is crf_accuracy) before re-enabling.
# plt.style.use("ggplot")
# plt.figure(figsize=(12,12))
# plt.plot(hist["acc"])
# plt.plot(hist["val_acc"])
# plt.show()

NameError: name 'history' is not defined

<Figure size 864x864 with 0 Axes>

In [None]:
# Predict on test
# Run the model over the held-out articles; one prediction row per timestep.
test_inputs = np.array(testX)
test_pred = model.model.predict(test_inputs, verbose=1)

In [None]:
# Accuracy over the positions whose gold tag id is non-zero. Tag id 0 is what
# the data-preparation cell uses for padding, so those positions are skipped.
truecnt = 0
falsecnt = 0
for pred_seq, gold_seq in zip(test_pred, testY):
    for pred_scores, gold_onehot in zip(pred_seq, gold_seq):
        gold_tag = np.argmax(gold_onehot)
        if gold_tag == 0:
            continue
        if np.argmax(pred_scores) == gold_tag:
            truecnt += 1
        else:
            falsecnt += 1

# Guard the ratio: with an empty test split, or one containing only tag-0
# positions, there is nothing to score and the division would fail.
scored = truecnt + falsecnt
accuracy = truecnt / scored if scored else float('nan')
print(truecnt, falsecnt, accuracy)