In [None]:
import json
import numpy as np
import random
import datetime

import tensorflow as tf

from IPython.core.display import display, HTML

In [None]:
# Hyperparameters shared by data preparation, the model definition and inference.
batch_size = 5   # mini-batch size passed to model.fit; also the replication count in parse_sent
vec_dim = 100    # length of one word vector (model input feature size, size of zero padding rows)
state_dim = 50   # number of units given to the LSTM layer
sent_len = 20    # every sentence is truncated or zero-padded to this many tokens

# Load word embedding vectors

In [None]:
def load_vec(path='tokens.vec'):
    """Read a text embedding file into a token -> vector dictionary.

    The first line of the file is skipped (it is treated as a header). Each
    following non-blank line is expected to contain a token followed by its
    space-separated float components.

    Rows are stripped before splitting, so the last component is kept whether
    or not the row ends with a trailing space.

    Args:
        path: Location of the embedding file. Defaults to 'tokens.vec'.

    Returns:
        dict mapping each token (str) to its vector (list of float).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a component cannot be parsed as a float.
    """
    dict_vec = {}
    with open(path, 'r', encoding='utf-8') as f:
        f.readline()  # skip the header line
        for line in f:
            # Strip the newline and any trailing separator before splitting.
            stripped = line.rstrip()
            if not stripped:
                continue  # ignore blank lines instead of storing a bogus token
            parts = stripped.split(' ')
            token = parts[0]
            # Empty fragments come from repeated separators; skip them.
            dict_vec[token] = [float(v) for v in parts[1:] if v]
    return dict_vec

def get_vec(k):
    """Look up the embedding of token `k` in the global `dict_vec`.

    Returns the stored vector (the very same list object held by the
    dictionary) when the token is known, otherwise a fresh all-zero list of
    length `vec_dim`.
    """
    return dict_vec[k] if k in dict_vec else [0.0] * vec_dim

In [None]:
# Build the global token -> vector table that get_vec reads.
dict_vec = load_vec()

# Load data

In [None]:
# docs.txt holds JSON despite its extension. Later code reads the 'text' and
# 'labels' entries of each document (see get_sents).
with open('docs.txt', 'r', encoding='utf-8') as f:
    docs = json.load(f)

In [None]:
# sents.txt holds JSON despite its extension. Later code reads the 'tokens' and
# 'labels' entries of each sentence (see get_data).
with open('sents.txt', 'r', encoding='utf-8') as f:
    sents = json.load(f)

In [None]:
# Keep only the first 550 documents for sentence splitting and the demo cells.
# NOTE(review): the reason for the 550 cut-off is not visible here — confirm.
docs2 = docs[:550]

In [None]:
# Training candidates: only sentences in which at least one label equals 1.
sents2 = [s for s in sents if s['labels'].count(1) > 0]

In [None]:
def get_sents(doc):
    """Split a document into clauses at Chinese punctuation marks.

    `doc['text']` and `doc['labels']` are walked in parallel (stopping at the
    shorter of the two). Each of the delimiters '，', '。', '？', '！' closes the
    current clause; the delimiter and its label are not stored. A delimiter
    that directly follows another one yields an empty clause.

    Text left after the last delimiter is emitted as a final clause, so a
    document that does not end with punctuation keeps its tail.

    Args:
        doc: Mapping with parallel sequences under 'text' and 'labels'.

    Returns:
        list of dicts, each with 'tokens' and 'labels' lists of equal length.
    """
    delimiters = ('，', '。', '？', '！')

    sents = []
    tokens = []
    labels = []

    for t, l in zip(doc['text'], doc['labels']):
        if t in delimiters:
            sents.append({
                'tokens': tokens,
                'labels': labels
            })
            tokens = []
            labels = []
        else:
            tokens.append(t)
            labels.append(l)

    # Flush the clause that follows the last delimiter (if any).
    if tokens:
        sents.append({
            'tokens': tokens,
            'labels': labels
        })

    return sents

In [None]:
# Flatten the clauses of every selected document into one list.
sents3 = [clause for d in docs2 for clause in get_sents(d)]

In [None]:
def get_data(sents, sent_len, vec_dim):
    """Turn sentences into fixed-length feature and label arrays.

    Each sentence contributes exactly `sent_len` time steps: the first tokens
    are embedded with `get_vec` and paired with `float(label)`; longer
    sentences are truncated, shorter ones are padded with zero vectors of
    length `vec_dim` and a label of 0.0.

    Args:
        sents: Iterable of dicts with 'tokens' and 'labels' sequences.
        sent_len: Number of time steps per sentence.
        vec_dim: Length of the zero vector used for padding.

    Returns:
        Tuple (x, y) of numpy arrays built from the per-sentence lists.
    """
    target_len = max(sent_len, 0)
    features = []
    targets = []

    for entry in sents:
        tokens, labels = entry['tokens'], entry['labels']
        n_real = min(len(tokens), target_len)
        n_pad = target_len - n_real

        rows = [get_vec(tokens[i]) for i in range(n_real)]
        rows += [[0.0] * vec_dim for _ in range(n_pad)]
        marks = [float(labels[i]) for i in range(n_real)] + [0.0] * n_pad

        features.append(rows)
        targets.append(marks)

    return np.array(features), np.array(targets)

In [None]:
# Vectorise the sentences that contain at least one positive label (sents2).
# The whole set is used for training; no separate test split is made here.
x, y = get_data(sents2, sent_len, vec_dim)
x_train = x
y_train = y

# Build model

In [None]:
# Token-level tagger: an LSTM emits one state per time step, a sigmoid Dense
# layer maps each state to a single score, and the trailing size-1 axis is
# removed so the output lines up with the per-token label array.
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=(sent_len, vec_dim)))
model.add(tf.keras.layers.LSTM(state_dim, return_sequences=True))
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
# Squeeze only the last axis. Without an explicit axis, tf.squeeze removes
# every size-1 dimension, which would also drop a batch dimension of 1.
model.add(tf.keras.layers.Lambda(lambda x: tf.squeeze(x, axis=-1)))
model.summary()

# Train

In [None]:
# Load the TensorBoard notebook extension (needed by the %tensorboard magic below).
%load_ext tensorboard

In [None]:
# Write each run to its own timestamped sub-directory of logs/ so runs do not
# overwrite each other; histogram logging is enabled via histogram_freq=1.
log_dir="logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# Root-mean-square error between predictions and labels: the squared
# differences are averaged (no axis given to reduce_mean) and the square root
# of that mean is returned.
loss = lambda label, outputs: tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(outputs, label))))

In [None]:
# Adadelta optimizer with explicitly chosen hyperparameters.
optimizer = tf.keras.optimizers.Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-06)

In [None]:
# Train against the custom RMSE loss and also report mean squared error.
model.compile(loss=loss, optimizer=optimizer, metrics=["mse"])

In [None]:
# 50 epochs with 10% of the training arrays held out for validation; progress
# is logged through the TensorBoard callback.
model.fit(x_train, y_train, batch_size=batch_size, epochs=50, validation_split=0.1, callbacks=[tensorboard_callback])

In [None]:
# Open TensorBoard inline on the parent log directory so every run is listed.
%tensorboard --logdir logs

# Visualization

In [None]:
def build_sent_html(text, labels):
    """Render tokens as HTML spans, colouring positively labelled ones red.

    Tokens and labels are paired position by position; the shorter of the two
    sequences decides how many tokens are rendered. A token whose label
    compares equal to 1 gets a red span, every other token a plain span.

    Token text is HTML-escaped, so characters such as '<', '>' and '&' are
    shown literally instead of being interpreted as markup.

    Args:
        text: Sequence of token strings.
        labels: Sequence of labels aligned with `text`.

    Returns:
        str with the concatenated span elements ('' for empty input).
    """
    from html import escape  # function-scope import keeps this cell self-contained

    spans = []
    for token, label in zip(text, labels):
        safe = escape(token)
        if label == 1:
            spans.append('<span style="color:red;">' + safe + '</span>')
        else:
            spans.append('<span>' + safe + '</span>')

    return ''.join(spans)

def print_sent(text, labels):
    """Show the labelled tokens as rich HTML output in the notebook."""
    markup = build_sent_html(text, labels)
    display(HTML(markup))

In [None]:
def print_compare_sent(s):
    """Display one sentence twice: with its gold labels, then with predictions.

    The predicted labels come from `predict_label`, so the thresholding logic
    lives in a single place instead of being repeated here.

    Args:
        s: Sentence dict with 'tokens' and 'labels'.
    """
    tokens = s['tokens']
    print_sent(tokens, s['labels'])
    print_sent(tokens, predict_label(s))

In [None]:
def parse_sent(s, sent_len, vec_dim, batch_size):
    """Vectorise a single sentence and repeat it into a batch.

    The sentence is truncated or padded to `sent_len` steps: real tokens are
    embedded with `get_vec` and keep their label as-is, padding steps use a
    zero vector of length `vec_dim` and the label 0. The resulting sentence is
    then repeated `batch_size` times.

    Args:
        s: Sentence dict with 'tokens' and 'labels'.
        sent_len: Number of time steps to produce.
        vec_dim: Length of the zero vector used for padding.
        batch_size: How many copies of the sentence the batch holds.

    Returns:
        Tuple (x, y) of numpy arrays holding the repeated features and labels.
    """
    tokens, labels = s['tokens'], s['labels']
    target_len = max(sent_len, 0)
    n_real = min(len(tokens), target_len)
    n_pad = target_len - n_real

    rows = [get_vec(tokens[i]) for i in range(n_real)]
    rows += [[0.0] * vec_dim for _ in range(n_pad)]
    marks = [labels[i] for i in range(n_real)] + [0] * n_pad

    return np.array([rows] * batch_size), np.array([marks] * batch_size)

def predict_label(s):
    """Predict a hard 0/1 label for every time step of one sentence.

    The sentence is batched with `parse_sent`, run through the global `model`,
    and the scores of the first batch entry are thresholded at 0.5 (scores at
    or above the threshold become 1, scores below it become 0).

    Returns:
        list with one thresholded score per time step.
    """
    px, _ = parse_sent(s, sent_len, vec_dim, batch_size)
    scores = model(px).numpy()[0]

    at_or_above = scores >= .5
    below = scores < .5
    scores[at_or_above] = 1
    scores[below] = 0

    return list(scores)

In [None]:
def compare_doc(sents):
    """Display a document's clauses with gold labels, then with predictions.

    Each clause is rendered with `build_sent_html` and followed by a '，'
    separator. The gold rendering is displayed first, a blank gap is printed,
    and the rendering based on `predict_label` is displayed second.

    Args:
        sents: Iterable of sentence dicts with 'tokens' and 'labels'.
    """
    gold_markup = ''.join(
        build_sent_html(s['tokens'], s['labels']) + '，' for s in sents
    )
    display(HTML(gold_markup))

    print('\n')
    predicted_markup = ''.join(
        build_sent_html(s['tokens'], predict_label(s)) + '，' for s in sents
    )
    display(HTML(predicted_markup))

In [None]:
# Spot check: pick a random training sentence and show gold vs. predicted labels.
# No random seed is set, so a different sentence is shown on every run.
p = random.choice(sents2)
print_compare_sent(p)

In [None]:
# Spot check on a whole document: split it into clauses and compare gold vs.
# predicted labels clause by clause.
# NOTE(review): this rebinds sents3, replacing the corpus-wide clause list
# built earlier in the notebook.
doc = random.choice(docs2)
sents3 = get_sents(doc)
compare_doc(sents3)