Modified batch_yield in data.py.
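The original function:

import random  # used for the optional shuffle

def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
    """
    :param data: list of (sentence, tags) pairs
    :param batch_size: number of sentences per batch
    :param vocab: character-to-id mapping passed to sentence2id
    :param tag2label: tag-to-label-id mapping
    :param shuffle: whether to shuffle data in place before batching
    :return: generator of (seqs, labels) batches
    """
    if shuffle:
        random.shuffle(data)

    seqs, labels = [], []
    for (sent_, tag_) in data:
        sent_ = sentence2id(sent_, vocab)  # sentence2id: helper from data.py (not shown here)
        label_ = [tag2label[tag] for tag in tag_]

        # Emit the full batch before appending the current pair.
        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []

        seqs.append(sent_)
        labels.append(label_)

    # Flush the final, possibly smaller, batch.
    if len(seqs) != 0:
        yield seqs, labels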
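label_ needs to be changed along the lines of https://www.tensorflow.org/tutorials/text/text_generation, where the target sequence is the input sequence shifted by one character.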
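A minimal sketch of what the modified labels could look like, assuming the pre-training target is the next character, as in the tutorial's input/target split (input = sequence[:-1], target = sequence[1:]). The function name batch_yield_pretrain and the simplified sentence2id stand-in are hypothetical; the real sentence2id in data.py may normalize characters differently. The NER tags in data are ignored here.

```python
import random


# Hypothetical stand-in for data.py's sentence2id: maps each character to its
# vocab id, falling back to '<UNK>'. The real helper may do more normalization.
def sentence2id(sent, vocab):
    return [vocab.get(ch, vocab['<UNK>']) for ch in sent]


def batch_yield_pretrain(data, batch_size, vocab, shuffle=False):
    """Hypothetical pre-training variant of batch_yield.

    Each label sequence is the input sequence shifted left by one character,
    so the model learns to predict the next character. Tags are ignored.
    """
    if shuffle:
        random.shuffle(data)

    seqs, labels = [], []
    for (sent_, _tags) in data:
        ids = sentence2id(sent_, vocab)
        if len(ids) < 2:
            # A one-character sentence has no next-character target; skip it.
            continue

        # Same batching logic as batch_yield: emit a full batch first.
        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []

        seqs.append(ids[:-1])   # input: every character except the last
        labels.append(ids[1:])  # target: every character except the first

    # Flush the final, possibly smaller, batch.
    if len(seqs) != 0:
        yield seqs, labels


if __name__ == '__main__':
    vocab = {'<UNK>': 0, '我': 1, '爱': 2, '北': 3, '京': 4}
    data = [(list('我爱北京'), ['O', 'O', 'B-LOC', 'I-LOC'])]
    print(list(batch_yield_pretrain(data, batch_size=2, vocab=vocab)))
    # prints [([[1, 2, 3]], [[2, 3, 4]])]
```

Because input and target are both one character shorter than the sentence, every label list stays the same length as its input list, so the existing padding logic should still apply unchanged.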
Add a pre_train mode to main.py.
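One way the new mode could be wired in, assuming main.py selects its behavior through an argparse --mode flag. The other mode names (train, test) and the run_* handlers below are hypothetical placeholders, not the project's actual code.

```python
import argparse


# Hypothetical placeholder handlers; each would build the model and run the
# corresponding loop. They return the mode name so the dispatch is visible.
def run_pretrain(args):
    # placeholder: train on next-character targets
    return 'pre_train'


def run_train(args):
    # placeholder: train on NER tag labels
    return 'train'


def run_test(args):
    # placeholder: evaluate a saved model
    return 'test'


HANDLERS = {'pre_train': run_pretrain, 'train': run_train, 'test': run_test}


def main(argv=None):
    parser = argparse.ArgumentParser()
    # 'choices' makes argparse reject unknown modes with a usage error.
    parser.add_argument('--mode', type=str, default='train',
                        choices=sorted(HANDLERS), help='pre_train/train/test')
    args = parser.parse_args(argv)
    return HANDLERS[args.mode](args)


if __name__ == '__main__':
    main()
```

A dict of handlers keeps each mode's setup in its own function, so adding pre_train does not touch the existing branches.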