### 在这里测试模型

In [None]:
import numpy as np
import tensorflow as tf
import keras
import keras.layers as layers
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array

In [None]:
def load_BERT_model(config_path='BERT/chinese_rbt6_L-6_H-768_A-12/bert_config_rbt6.json',
                    checkpoint_path = 'BERT/chinese_rbt6_L-6_H-768_A-12/bert_model.ckpt',
                    dict_path = 'BERT/chinese_rbt6_L-6_H-768_A-12/vocab.txt'):
    """ load BERT """
    tokenizer = Tokenizer(dict_path, do_lower_case=True)  # 建立分词器
    model = build_transformer_model(config_path, checkpoint_path)  # 建立模型，加载权重

    return tokenizer, model



tokenizer, model = load_BERT_model()


# 编码测试
token_ids, segment_ids = tokenizer.encode(u'')
token_ids, segment_ids = to_array([token_ids], [segment_ids])

print(token_ids, segment_ids)

print('\n ===== predicting =====\n')
print(model.predict([token_ids, segment_ids]))

# print('\n ===== reloading and predicting =====\n')
# model.save('test.model')
# del model
# model = keras.models.load_model('test.model')
# print(type(model))
# print(model.predict([token_ids, segment_ids]).shape)
model.summary()

In [None]:
def build_sentence_encoder(num_of_LSTMs=2):
    """ simple LSTM encoder
    
        sentence sequence |-> embedding of sentence
    
    """
    inputs = keras.Input(shape=[None, 768])
    sequence = inputs
    for _ in range(num_of_LSTMs):
        sequence, hidden_state, cell_state = layers.LSTM(768,return_state=True, return_sequences=True)(sequence)
    outputs = layers.concatenate([hidden_state, cell_state], axis=1)
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    return model


sent_encoder = build_sentence_encoder()
x = tf.random.normal([100, 10, 768])
sent_encoder(x).shape



In [None]:
lstm = sent_encoder.get_layer('lstm_41')
for weight in lstm.weights:
    print(weight.name, weight.shape)

lstm.set_weights([tf.random.normal([768, 3072]), tf.random.normal([768, 3072]), tf.random.normal([3072])])

In [109]:
def sim(sent1, sent2):
    """ similarity between two sentences
    
    naive choice is cosine similarity. could be replaced.
    """
    return np.dot(sent1, sent2)


def coverage_score(sentences, selection, alpha):
    """ the coverage score of a subset of sentences 
    
        it measure how much the selection covers all the information of the whole corpus.
    """
    N, S = len(sentences), len(selection)
    score = 0
    for i in range(N):
        score += tf.minimum(sum([sim(sentences[i], selection[j]) for j in range(S)]),
                            alpha * sum([sim(sentences[i], sentences[j]) for j in range(N)])
                            )

    return score


def document_extract(sentence_embeddings, sentence_lengths, max_budget, alpha=0.9, paragraph_partitions=None):
    """ extract key sentences from the whole corpus

    sentence_embeddings: dtype is numpy.ndarray ; dimension = [len_of_sequence, embedding_dim]
    sentence_lengths: a list of integers, each entry is the length of the corresponding sentence
    max_budget: the maximal length of the extracted text
    
    """
    N = len(sentence_embeddings)
    indices = set(range(N))
    selection = set()
    selection_length = 0
    while True:
        no_fit = True
        print(indices - selection)
        for i in indices - selection:
            if selection_length + sentence_lengths[i] > max_budget:
                continue

            selection_indices = list(selection.union(set([i])))
            if no_fit:    
                best_increment = coverage_score(sentence_embeddings, sentence_embeddings[selection_indices], alpha)
                best_id = i
            else:
                increment = coverage_score(sentence_embeddings, sentence_embeddings[selection_indices], alpha)
                if increment > best_increment:
                    best_increment = increment
                    best_id = i
            
            no_fit = False

        selection.add(best_id)
        selection_length += sentence_lengths[best_id]

        if no_fit:
            break

    return selection




In [110]:
sentences = np.random.randn(10, 768)
print(document_extract(sentences, [1] * 10, 6, alpha=1))

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
{1, 2, 3, 4, 5, 6, 7, 8, 9}
{1, 2, 3, 4, 5, 6, 8, 9}
{1, 2, 3, 4, 5, 6, 8}
{1, 2, 3, 4, 5, 8}
{1, 2, 3, 5, 8}
{8, 1, 2, 3}
{0, 4, 5, 6, 7, 9}


In [91]:
s = set()
s.union(set([1]))
s

set()