In [None]:
import os
import logging
import tensorflow_hub as hub
# INFO-level logging for the whole notebook, and expose only GPU 0 to TensorFlow.
logging.basicConfig(level=logging.INFO)
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
from bert_serving.client import BertClient
# NOTE(review): `bc` is not referenced anywhere else in this notebook. Constructing a
# BertClient presumably requires a running bert-serving server; confirm this cell is
# still needed, or drop it.
bc = BertClient()

In [None]:
# Import the class directly, so the name `datahelper` refers only to the instance
# below and never to the module.
from datahelper import DataHelper

datahelper = DataHelper(embedding_path="../embedding/STCWiki/STCWiki_mincount0.model.bin")

In [None]:
# Text-preprocessing options forwarded to DataHelper for every split.
TOKEN_TYPE = 'nltk'        # presumably the tokenizer backend
TO_LOWER = True
REMOVE_STOPWORDS = False
EMB = 'stc'                # word embeddings: 'glove' or 'stc'

In [None]:
# Build train/dev features (including BERT inputs) with the shared preprocessing options.
trainX, trainX_bert, trainND, trainDQ, train_turns, train_masks = datahelper.get_model_train_data(
    'train', TOKEN_TYPE, REMOVE_STOPWORDS, TO_LOWER, EMB, bert=True)

devX, devX_bert, devND, devDQ, dev_turns, dev_masks = datahelper.get_model_train_data(
    'dev', TOKEN_TYPE, REMOVE_STOPWORDS, TO_LOWER, EMB, bert=True)

# The test-data helper takes no split name and returns no ND/DQ values.
testX, testX_bert, test_turns, test_masks = datahelper.get_model_test_data(
    TOKEN_TYPE, REMOVE_STOPWORDS, TO_LOWER, EMB, bert=True)

In [None]:
trainX.shape, trainX_bert.shape  # train features vs. their BERT counterparts

In [None]:
devX.shape, devX_bert.shape  # dev features vs. their BERT counterparts

In [None]:
import pickle
# Cache the BERT features. `with` guarantees each file is closed and flushed, even if
# a dump raises. A bare open(...) would leave closing to the garbage collector.
with open('trainX_bert_512_sent.p', 'wb') as fh:
    pickle.dump(trainX_bert, fh)
with open('devX_bert_512_sent.p', 'wb') as fh:
    pickle.dump(devX_bert, fh)
with open('testX_bert_512_sent.p', 'wb') as fh:
    pickle.dump(testX_bert, fh)

In [None]:
# NOTE(review): these read 'trainX_bert.p' etc., not the '*_bert_512_sent.p' files
# dumped in the previous cell, and they overwrite trainX/devX/testX. Confirm which
# cache is intended. Only unpickle files this project wrote (pickle can run code).
with open('trainX_bert.p', 'rb') as fh:
    trainX = pickle.load(fh)
with open('devX_bert.p', 'rb') as fh:
    devX = pickle.load(fh)
with open('testX_bert.p', 'rb') as fh:
    testX = pickle.load(fh)

In [None]:
# `devX_berts` is never defined (NameError on a fresh kernel). The cached BERT features
# loaded in the previous cell live in `devX`, so compare those with `devX_bert`.
devX.shape, devX_bert.shape

### TensorFlow Hub

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import bert
import os
from bert import run_classifier
from bert import optimization
from bert import tokenization
# Same GPU restriction as the first cell; presumably repeated so this section also works
# when run on its own after a kernel restart.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# TF-Hub handle of the uncased BERT-Base model (12 layers, hidden size 768, 12 heads).
bert_hub_model_handle = (
    "https://tfhub.dev/google/"
    "bert_uncased_L-12_H-768_A-12/1"
)

In [None]:
def get_sess_config():
    """Return a TF1 session config with soft device placement and on-demand GPU memory."""
    return tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True),
    )

def create_tokenizer_from_hub_module(bert_hub_model_handle, sess_config=None):
    """Get the vocab file and casing info from the Hub module and build a tokenizer.

    Args:
        bert_hub_model_handle: TF-Hub handle of a BERT module that exposes a
            ``tokenization_info`` signature.
        sess_config: optional ``tf.ConfigProto`` for the temporary session. When
            omitted, a fresh one is built with ``get_sess_config()``.

    Returns:
        A ``tokenization.FullTokenizer`` using the module's vocab file and
        ``do_lower_case`` flag.
    """
    # Build the default per call. A default of `get_sess_config()` would be evaluated
    # once at definition time, and every call would share that one ConfigProto object.
    if sess_config is None:
        sess_config = get_sess_config()
    # A private graph keeps the module out of the notebook's default graph.
    with tf.Graph().as_default():
        bert_module = hub.Module(bert_hub_model_handle, trainable=True)
        tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
        with tf.Session(config=sess_config) as sess:
            vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                                  tokenization_info["do_lower_case"]])
    return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)

In [None]:
# BERT tokenizer matching the vocab/casing of the Hub module used throughout.
tokenizer = create_tokenizer_from_hub_module(bert_hub_model_handle=bert_hub_model_handle)

In [None]:
# Build BERT inputs for one example sentence: [CLS] + tokens + [SEP], a single
# segment, and no padding (so the mask is all ones).
sent = 'You may try restarting the Wechat app.'
bert_tokens = ["[CLS]"] + tokenizer.tokenize(sent) + ["[SEP]"]
input_ids = tokenizer.convert_tokens_to_ids(bert_tokens)
input_mask = [1 for _ in input_ids]
segment_ids = [0 for _ in input_ids]

In [None]:
bert_tokens  # "[CLS]", then tokenizer.tokenize(sent), then "[SEP]"

In [None]:
input_ids  # vocabulary ids of bert_tokens

In [None]:
input_mask  # all 1s: this single example is unpadded

In [None]:
segment_ids  # all 0s: single-segment input

In [None]:
# Run the single example above (batch of 1) through the Hub BERT module and fetch
# concrete numpy outputs. Opening a session without running it would only produce
# symbolic tensors from a graph that is discarded when the `with` block exits.
bert_inputs = dict(
    input_ids=[input_ids],
    input_mask=[input_mask],
    segment_ids=[segment_ids],
)

with tf.Graph().as_default():
    bert_module = hub.Module(bert_hub_model_handle, trainable=True)
    bert_outputs = bert_module(bert_inputs, signature="tokens", as_dict=True)
    with tf.Session(config=get_sess_config()) as sess:
        # Hub module weights are loaded via the variable initialisers.
        sess.run(tf.global_variables_initializer())
        pooled_output, sequence_output = sess.run(
            [bert_outputs["pooled_output"], bert_outputs["sequence_output"]]
        )

### BERT Preprocess

In [None]:
import pickle
# NOTE: pickle.load can execute arbitrary code; only load corpus files this project
# produced. `with` ensures each file handle is closed after reading.
with open("train_corpus.p", "rb") as fh:
    train_corpus = pickle.load(fh)
with open("dev_corpus.p", "rb") as fh:
    dev_corpus = pickle.load(fh)
with open("test_corpus.p", "rb") as fh:
    test_corpus = pickle.load(fh)

In [None]:
# Field [2] of each corpus record is used as a dialogue: a sequence of utterance
# strings, which later cells tokenise one by one.
trainX = [train_corpus[i][2] for i in range(len(train_corpus))]
devX = [dev_corpus[i][2] for i in range(len(dev_corpus))]
# Must index test_corpus. Indexing train_corpus here silently turned the first
# len(test_corpus) *training* dialogues into the "test" set.
testX = [test_corpus[i][2] for i in range(len(test_corpus))]

In [None]:
# BERT input grid: token ids per utterance (max_len) x utterances per dialogue (max_sent).
max_len, max_sent = 150, 7

In [None]:
def get_bert_tokens(utterance):
    """Convert one utterance string into BERT token ids.

    The sequence is "[CLS]", then ``tokenizer.tokenize(utterance)``, then "[SEP]".
    No truncation or padding is applied here. Uses the notebook-level ``tokenizer``.
    """
    wrapped = ["[CLS]", *tokenizer.tokenize(utterance), "[SEP]"]
    return tokenizer.convert_tokens_to_ids(wrapped)

In [None]:
def convert_corpus_to_bert_tokens(X):
    """Map every utterance of every dialogue in ``X`` to BERT token ids.

    Returns a new nested list with the same dialogue -> utterance structure as ``X``.
    Each utterance is replaced by ``get_bert_tokens(utterance)``.
    """
    return [[get_bert_tokens(utt) for utt in dialogue] for dialogue in X]

In [None]:
def get_max_len(X, threshold=150):
    """Return the length of the longest utterance (in token ids) across all dialogues.

    Side effect: prints the number of utterances longer than ``threshold`` divided by
    the number of dialogues. This is over-length utterances *per dialogue*, not a
    per-utterance fraction.

    Args:
        X: iterable of dialogues, each a sequence of token-id sequences.
        threshold: length above which an utterance counts as over-length. Defaults
            to 150, the previously hard-coded value.

    Returns:
        The maximum utterance length, or 0 when ``X`` holds no utterances.
    """
    longest = 0
    n_over = 0
    n_dialogues = 0
    for dialogue in X:
        n_dialogues += 1
        for utterance in dialogue:
            if len(utterance) > threshold:
                n_over += 1
            longest = max(len(utterance), longest)
    # Guard against ZeroDivisionError for an empty corpus.
    print(n_over / n_dialogues if n_dialogues else 0.0)
    return longest

In [None]:
def get_maskid_seqmentid(X, seq_len=None):
    """Build BERT attention masks and segment ids for every utterance in ``X``.

    Args:
        X: list of dialogues, each a list of token-id sequences (as produced by
            ``convert_corpus_to_bert_tokens``).
        seq_len: width of every output row. Defaults to the notebook-level ``max_len``.

    Returns:
        ``(dialogue_masks, dialogue_segids)``: nested lists mirroring ``X``, where each
        utterance becomes a row of length ``seq_len``. A mask row holds 1 for the first
        ``min(seq_len, len(utterance))`` positions and 0 after. Segment-id rows are
        all 0 (single-segment input).
    """
    if seq_len is None:
        seq_len = max_len  # notebook-level setting
    dialogue_masks = []
    dialogue_segids = []
    for dialogue in X:
        utterance_masks = []
        utterance_segids = []
        for utterance in dialogue:
            n_real = min(seq_len, len(utterance))
            utterance_masks.append([1] * n_real + [0] * (seq_len - n_real))
            utterance_segids.append([0] * seq_len)
        dialogue_masks.append(utterance_masks)
        dialogue_segids.append(utterance_segids)
    return dialogue_masks, dialogue_segids

In [None]:
# Tokenise each split into nested lists: dialogue -> utterance -> token ids.
trainX_input_ids, devX_input_ids, testX_input_ids = (
    convert_corpus_to_bert_tokens(split) for split in (trainX, devX, testX)
)

In [None]:
# Longest tokenised utterance over all splits; each get_max_len call also prints its ratio.
max(get_max_len(trainX_input_ids), get_max_len(devX_input_ids), get_max_len(testX_input_ids))

In [None]:
# Attention masks and segment ids for each split, in train/dev/test order.
(
    (trainX_input_masks, trainX_segment_ids),
    (devX_input_masks, devX_segment_ids),
    (testX_input_masks, testX_segment_ids),
) = map(get_maskid_seqmentid, (trainX_input_ids, devX_input_ids, testX_input_ids))

In [None]:
# Sanity check: ids are still unpadded here, while mask/segment rows are already max_len wide.
len(trainX_input_ids[0][0]), len(trainX_input_masks[0][0]), len(trainX_segment_ids[0][0])

In [None]:
def bert_padding(input_ids, input_masks, segment_ids, seq_len=None, num_sent=None):
    """Pad BERT inputs so each dialogue has ``num_sent`` rows of ``seq_len`` entries.

    For every dialogue ``i``:
      * ``input_ids[i]``, ``input_masks[i]`` and ``segment_ids[i]`` get all-zero rows
        appended until they hold ``num_sent`` utterances. Longer dialogues are left as-is.
      * every token-id row is truncated or right-padded with 0 to exactly ``seq_len``.
        Mask and segment rows are not resized; they are expected to be ``seq_len`` wide
        already, as built by ``get_maskid_seqmentid``.

    The arguments are not modified: new nested lists are returned. Every padding row
    is a separate list object, so no rows are aliased across ids/masks/segment ids.

    Args:
        input_ids: list of dialogues, each a list of token-id lists.
        input_masks: attention masks with the same dialogue/utterance nesting.
        segment_ids: segment ids with the same nesting.
        seq_len: row width. Defaults to the notebook-level ``max_len``.
        num_sent: utterances per dialogue. Defaults to the notebook-level ``max_sent``.

    Returns:
        ``(input_ids, input_masks, segment_ids)`` as padded copies.
    """
    if seq_len is None:
        seq_len = max_len
    if num_sent is None:
        num_sent = max_sent

    def _pad_dialogue(rows):
        # Copy the rows, then append fresh all-zero rows up to num_sent.
        padded = [list(row) for row in rows]
        padded.extend([0] * seq_len for _ in range(num_sent - len(padded)))
        return padded

    def _fit_row(row):
        # Truncate or zero-pad one token-id row to exactly seq_len entries.
        # NOTE(review): truncation drops the trailing [SEP] of over-long utterances.
        fitted = list(row[:seq_len])
        return fitted + [0] * (seq_len - len(fitted))

    ids_out, masks_out, segids_out = [], [], []
    for i in range(len(input_ids)):
        ids_out.append([_fit_row(row) for row in _pad_dialogue(input_ids[i])])
        masks_out.append(_pad_dialogue(input_masks[i]))
        segids_out.append(_pad_dialogue(segment_ids[i]))
    return ids_out, masks_out, segids_out

In [None]:
# Pad every split with bert_padding (train, then dev, then test).
trainX_input_ids_pad, trainX_input_masks_pad, trainX_segment_ids_pad = bert_padding(
    trainX_input_ids, trainX_input_masks, trainX_segment_ids)
devX_input_ids_pad, devX_input_masks_pad, devX_segment_ids_pad = bert_padding(
    devX_input_ids, devX_input_masks, devX_segment_ids)
testX_input_ids_pad, testX_input_masks_pad, testX_segment_ids_pad = bert_padding(
    testX_input_ids, testX_input_masks, testX_segment_ids)

In [None]:
# Every padded split should now have max_sent rows of max_len entries per dialogue.
# Check ids, masks and segment ids of all three splits (not only the test segment
# ids), using the configured constants instead of hard-coded 7 / 150.
for split_name, split_arrays in (
    ('train', (trainX_input_ids_pad, trainX_input_masks_pad, trainX_segment_ids_pad)),
    ('dev', (devX_input_ids_pad, devX_input_masks_pad, devX_segment_ids_pad)),
    ('test', (testX_input_ids_pad, testX_input_masks_pad, testX_segment_ids_pad)),
):
    for array in split_arrays:
        for dialogue in array:
            assert len(dialogue) == max_sent, (split_name, len(dialogue))
            for utterance in dialogue:
                assert len(utterance) == max_len, (split_name, len(utterance))

In [None]:
import numpy as np

# Shapes of the first padded training dialogue: ids, masks, segment ids.
tuple(
    np.asarray(split[0]).shape
    for split in (trainX_input_ids_pad, trainX_input_masks_pad, trainX_segment_ids_pad)
)

In [None]:
# Flatten the first three padded dialogues into flat per-utterance lists.
unstacked_input_ids, unstacked_input_mask, unstacked_segment_ids = [], [], []

for _id, mask, segid in zip(trainX_input_ids_pad[:3], trainX_input_masks_pad[:3], trainX_segment_ids_pad[:3]):
    unstacked_input_ids += _id
    unstacked_input_mask += mask
    unstacked_segment_ids += segid

In [None]:
# Run the first padded training dialogue (one row per utterance) through the Hub BERT
# module and fetch concrete numpy outputs. Opening a session without running it would
# only yield symbolic tensors from a graph discarded when the `with` block exits.
bert_inputs = dict(
    input_ids=trainX_input_ids_pad[0],
    input_mask=trainX_input_masks_pad[0],
    segment_ids=trainX_segment_ids_pad[0],
)

with tf.Graph().as_default():
    bert_module = hub.Module(bert_hub_model_handle, trainable=True)
    bert_outputs = bert_module(bert_inputs, signature="tokens", as_dict=True)
    with tf.Session(config=get_sess_config()) as sess:
        # Hub module weights are loaded via the variable initialisers.
        sess.run(tf.global_variables_initializer())
        pooled_output, sequence_output = sess.run(
            [bert_outputs["pooled_output"], bert_outputs["sequence_output"]]
        )

In [None]:
pooled_output.shape  # expected (max_sent, 768) for the H-768 BERT-Base module

In [None]:
# Persist the padded BERT inputs. `with` closes (and flushes) each file even if a dump
# fails. A bare open(...) would leave closing to the garbage collector.
for _path, _obj in (
    ('trainX_input_ids.p', trainX_input_ids_pad),
    ('devX_input_ids.p', devX_input_ids_pad),
    ('testX_input_ids.p', testX_input_ids_pad),
    ('trainX_input_masks.p', trainX_input_masks_pad),
    ('devX_input_masks.p', devX_input_masks_pad),
    ('testX_input_masks.p', testX_input_masks_pad),
    ('trainX_segment_ids.p', trainX_segment_ids_pad),
    ('devX_segment_ids.p', devX_segment_ids_pad),
    ('testX_segment_ids.p', testX_segment_ids_pad),
):
    with open(_path, 'wb') as fh:
        pickle.dump(_obj, fh)

## Test shape

In [None]:
import pickle
import numpy as np
import tensorflow_hub as hub
def _load_pickle(path):
    """Unpickle one object from ``path`` and close the file afterwards.

    NOTE: pickle can execute arbitrary code; only load files this notebook wrote.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


trainX_input_ids = _load_pickle('trainX_input_ids.p')
trainX_input_masks = _load_pickle('trainX_input_masks.p')
trainX_segment_ids = _load_pickle('trainX_segment_ids.p')
devX_input_ids = _load_pickle('devX_input_ids.p')
devX_input_masks = _load_pickle('devX_input_masks.p')
devX_segment_ids = _load_pickle('devX_segment_ids.p')
testX_input_ids = _load_pickle('testX_input_ids.p')
testX_input_masks = _load_pickle('testX_input_masks.p')
testX_segment_ids = _load_pickle('testX_segment_ids.p')

In [None]:
# Created in TensorFlow's default graph (no explicit tf.Graph), so the cells below can call it.
# The handle is written out instead of reusing `bert_hub_model_handle`, presumably so this
# section can run on its own after a restart.
bert_module = hub.Module("https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1", trainable=True)

In [None]:
def getshape(l):
    """Return the numpy shape tuple of ``l`` (nested lists or array-like)."""
    return np.shape(l)

In [None]:
# Inspect the shapes of the first three saved dialogues, one utterance at a time.
# NOTE(review): this overwrites the notebook-level `input_ids` / `input_mask` /
# `segment_ids` from the single-sentence demo earlier in the notebook.
input_ids = trainX_input_ids[0:3]
input_mask = trainX_input_masks[0:3]
segment_ids = trainX_segment_ids[0:3]
print(getshape(input_ids), getshape(input_mask), getshape(segment_ids))

# `idx` is unused. `bert_inputs` is reassigned on every inner iteration, so after the
# loop it holds only the LAST utterance of the third dialogue as a batch of size 1.
# The next cell feeds exactly that single example to BERT.
for idx, (dialog_ids, dialog_masks, dialog_segids) in enumerate(zip(input_ids, input_mask, segment_ids)):
    for _id, mask, segid in zip(dialog_ids, dialog_masks, dialog_segids):
        print(getshape(_id), getshape(mask), getshape(segid))
        print(getshape([_id]))
        bert_inputs = dict(
            input_ids=[_id],
            input_mask=[mask],
            segment_ids=[segid],
        )
        
    


In [None]:
# Builds symbolic output tensors only; no session is run here, so nothing is computed yet.
bert_outputs = bert_module(bert_inputs, signature="tokens", as_dict=True)

In [None]:
pooled_output = bert_outputs["pooled_output"]  # symbolic tensor; evaluate in a session for values

In [None]:
pooled_output  # displays the tensor's repr (name, shape, dtype), not its values