In [3]:
import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from cnpre import CNPRE, miniCNPRE
from model import LSTM
from bilstm import bilstm_crf
import data
import sys

def batch_padding(batch, pad_idx=None):
    """Collate a batch of (input_ids, target_ids) pairs into padded tensors.

    Each pair is right-padded up to the length of the longest ``item[0]`` in
    the batch. Padding is applied to fresh copies, so the lists owned by the
    caller (e.g. the dataset) are never modified and keep their true lengths
    when the same items are batched again.

    Args:
        batch: sequence of items where ``item[0]`` and ``item[1]`` are
            sequences of integer ids. ``item[1]`` receives the same number of
            padding entries as ``item[0]``.
        pad_idx: id used for padding. When omitted, the module-level
            ``word2idx['<pad>']`` is used.

    Returns:
        A tuple ``(tags, anss, maxlen, lens)``:
        ``tags`` - LongTensor built from the padded ``item[0]`` rows,
        ``anss`` - LongTensor built from the padded ``item[1]`` rows,
        ``maxlen`` - length of the longest ``item[0]`` (0 for an empty batch),
        ``lens`` - IntTensor holding the unpadded length of every ``item[0]``.
    """
    if pad_idx is None:
        pad_idx = word2idx['<pad>']

    maxlen = max((len(item[0]) for item in batch), default=0)

    tags, anss, lens = [], [], []
    for item in batch:
        length = len(item[0])
        padding = [pad_idx] * (maxlen - length)
        # Concatenation creates new lists, leaving the caller's data intact.
        tags.append(list(item[0]) + padding)
        anss.append(list(item[1]) + padding)
        lens.append(length)

    return torch.LongTensor(tags), torch.LongTensor(anss), maxlen, torch.IntTensor(lens)

# Hyperparameters
embedding_dim = 300
lstm_hidden_dim = 300
lr = 1e-3
batch_size = 32

# Build the language model
# Deserialize the vocabulary file once; it stores both lookup directions.
vocab = torch.load('wordidx.pt')
word2idx = vocab['word2idx']
idx2word = vocab['idx2word']
print('building model...', end='')
sys.stdout.flush()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # compute device
pad_index = word2idx['<pad>']
lstm = LSTM(len(word2idx), embedding_dim, lstm_hidden_dim, batch_size, pad_index)
lstm.set_device(device)
print('\tdone.')
# Restore parameters. map_location remaps the stored tensors onto the selected
# device, so a checkpoint saved on a GPU still loads on a CPU-only machine.
checkpoint = torch.load('./checkpoint/cn-pre-check-point lr=1e-4.pt', map_location=device)  # checkpoint path
lstm.load_state_dict(checkpoint['net'])

# Build the word-segmentation model
# NOTE: embedding_dim / lstm_hidden_dim are re-bound here with the values used
# by the segmentation model; `device` from above is reused.
embedding_dim = 256
lstm_hidden_dim = 128
w2idx, train_set = data.get_data('train')
tag2idx = data.tag2idx
print('building model...', end='')
model = bilstm_crf(tag2idx, len(w2idx), embedding_dim, lstm_hidden_dim)
model.set_device(device)
print('\tdone.')
# Restore parameters (same device remapping as above)
checkpoint = torch.load('./checkpoint/lr=1e-5.pt', map_location=device)
model.load_state_dict(checkpoint['net'])

building model...	done.
building model...	done.


<All keys matched successfully>

In [66]:
from random import randint

# Upper bound on the number of generated words, so that a model which never
# emits '<eos>' cannot keep the cell running forever.
max_new_words = 200

lstm.eval()
begin = input('句子开头：')
if not begin:
    raise ValueError('the sentence prefix must not be empty')

# --- Segment the prefix into words -----------------------------------------
# Characters missing from the segmentation vocabulary fall back to index 0.
sentence = [w2idx[word] if word in w2idx else 0 for word in begin]
sentence = torch.LongTensor(sentence).to(device)
predict_tag = torch.LongTensor(model.test(sentence, 1)[0])

# A word starts at every position tagged 'S' or 'B'.
tag_values = predict_tag.tolist()
sp = [
    index
    for index, tag in enumerate(tag_values)
    if tag == data.tag2idx['S'] or tag == data.tag2idx['B']
]
# Force a boundary at position 0 so that characters preceding the first
# 'S'/'B' tag are kept instead of being silently dropped (and so that `words`
# is never empty when no such tag is predicted).
if not sp or sp[0] != 0:
    sp.insert(0, 0)
sp.append(len(tag_values))
words = [begin[start:end] for start, end in zip(sp, sp[1:])]

print('分词结果：{}'.format(words))

# --- Continue the sentence with the language model -------------------------
# Seed with the last segmented word; if it is out of vocabulary, pick a random
# index. randint is inclusive at both ends, hence the `- 1`.
curword = word2idx[words[-1]] if words[-1] in word2idx else randint(0, len(word2idx) - 1)
h = lstm.h0
c = lstm.c0
eos_index = word2idx['<eos>']
# Inference only: no autograd graph needs to be recorded across steps.
with torch.no_grad():
    for _ in range(max_new_words):
        curword, h, c = lstm.pretend(curword, h, c)
        curword = curword.item()
        if curword == eos_index:
            break
        words.append(idx2word[curword])

sentence = ''.join(words)
print('生成句子：{}'.format(sentence))

句子开头： 自然语言处理


分词结果：['自然', '语言', '处理']
生成句子：自然语言处理好业主与产品结构不相分的双星相比，生产的主要分项研究等都在相当大的程度上是主要的。
