In [1]:
import sys
sys.path.append('../')

In [2]:
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import pickle
import json
import torch
from copy import deepcopy
from pytorch_nlp_models.text_pair.siamese_rnn import SiameseGRU
from utils.preprocess import text2ids
# from utils.datasets import LCQMCDataset
# from utils.model_utils import model_train, model_eval

# from torch.utils.data import DataLoader
# from dumb_containers import evaluate_performance

In [3]:
# Filesystem layout: every artefact is resolved relative to DATA_PATH.
DATA_PATH = '../data/'
WORD_VECTORS_PATH = os.path.join(DATA_PATH, 'word_vectors')
BAIDUBAIKE_PKL = os.path.join(WORD_VECTORS_PATH, 'baidubaike.pkl')

# Fixed sequence length fed to the model; sim() truncates or zero-pads to it.
MAX_SEQ_LEN = 40

MODEL_PATH = os.path.join(DATA_PATH, 'model_files/siamese_gru')
# exist_ok=True replaces the exists()/makedirs() pair, which could race with
# another process creating the directory between the check and the call.
os.makedirs(MODEL_PATH, exist_ok=True)

MODEL_FILE = os.path.join(MODEL_PATH, 'model.pkl')
MODEL_CONFIG_JSON = os.path.join(MODEL_PATH, 'config.json')

# Read the JSON as UTF-8 explicitly instead of relying on the platform's
# default text encoding. MODEL_CONFIG is later expanded as keyword arguments
# for the SiameseGRU constructor.
with open(MODEL_CONFIG_JSON, 'r', encoding='utf-8') as f:
    MODEL_CONFIG = json.load(f)

In [4]:
# Load the pickled Baidu Baike word-vector bundle.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point BAIDUBAIKE_PKL at a file from a trusted source.
with open(BAIDUBAIKE_PKL, 'rb') as f:
    wvs = pickle.load(f)

# Unpack the bundle's four entries into notebook globals.
# `wi` is the vocabulary handed to text2ids() inside sim(); presumably it maps
# token -> index, `iw` is the reverse map, `dim` the vector size and `emb` the
# embedding matrix -- TODO confirm against the script that wrote the pickle.
wi, iw, dim, emb = wvs['wi'], wvs['iw'], wvs['dim'], wvs['emb']

In [5]:
# Rebuild the network from the saved hyper-parameters.
model = SiameseGRU(**MODEL_CONFIG)

# map_location='cpu' remaps the stored tensors onto the CPU while loading.
checkpoint = torch.load(MODEL_FILE, map_location='cpu')

# Restore the trained weights; the call's return value is the cell output and
# reports any missing or unexpected keys.
model.load_state_dict(checkpoint['model_state_dict'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [6]:
# Display the embedding layer's state dict to eyeball the restored weights.
model.emb.state_dict()

OrderedDict([('weight',
              tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.0110,  0.2647,  0.4712,  ..., -0.1420,  0.5493,  0.4880],
                      [-0.1045, -0.4096,  0.0025,  ...,  0.2424,  0.5210,  0.0380],
                      ...,
                      [ 0.1317, -0.0819,  0.0877,  ..., -0.0862, -0.0418, -0.1139],
                      [ 0.0918,  0.1966, -0.0043,  ..., -0.1252,  0.0385,  0.0049],
                      [ 0.0351,  0.1157, -0.0244,  ..., -0.0970,  0.0307, -0.0839]]))])

In [7]:
def _fit_to_length(ids, max_len):
    """Truncate or right-pad one id sequence to exactly ``max_len`` items.

    Args:
        ids: sequence of token ids.
        max_len: required output length.

    Returns:
        ``(fixed_ids, true_len)``: ``fixed_ids`` is a *new* list of ``max_len``
        ids (padding value 0); ``true_len`` is the number of real ids kept,
        i.e. ``min(len(ids), max_len)``. ``ids`` itself is never modified.
    """
    true_len = min(len(ids), max_len)
    # Build a fresh list rather than extending `ids` in place, so a list that
    # the tokenizer might cache or share is not polluted with padding.
    fixed_ids = list(ids[:true_len]) + [0] * (max_len - true_len)
    return fixed_ids, true_len


def sim(text1, text2):
    """Score a pair of texts with the loaded Siamese GRU.

    Both texts are converted to id sequences with ``text2ids`` (``charmode=True``),
    truncated or zero-padded to ``MAX_SEQ_LEN``, and passed through ``model`` as
    a batch of one, in eval mode and without gradient tracking.

    Args:
        text1: first text.
        text2: second text.

    Returns:
        ``(pos_prob, vec1, vec2)``. ``pos_prob`` is the softmax probability at
        class index 1, as a Python float. ``vec1`` and ``vec2`` are the second
        and third values returned by ``model`` (presumably the two sentence
        encodings -- TODO confirm against SiameseGRU).

    Side effects:
        Prints the full probability tensor, and switches ``model`` to eval mode.
    """
    ids1, len1 = _fit_to_length(text2ids(text1, wi, charmode=True), MAX_SEQ_LEN)
    ids2, len2 = _fit_to_length(text2ids(text2, wi, charmode=True), MAX_SEQ_LEN)

    # The model is called with batched tensors, so wrap each sample in a
    # batch dimension of size one.
    ids1_tensor = torch.tensor([ids1], dtype=torch.long)
    ids2_tensor = torch.tensor([ids2], dtype=torch.long)
    len1_tensor = torch.tensor([len1], dtype=torch.long)
    len2_tensor = torch.tensor([len2], dtype=torch.long)

    model.eval()
    with torch.no_grad():
        logits, vec1, vec2 = model(ids1_tensor, ids2_tensor, len1_tensor, len2_tensor)
        probs = torch.softmax(logits, dim=1)
    # Kept deliberately: the recorded cell outputs include this tensor.
    print(probs)
    return probs[0, 1].item(), vec1, vec2

# 一般测试

In [8]:
# Same question in two word orders: "which League of Legends champion is the best".
text1 = "英雄联盟什么英雄最好"
text2 = "英雄联盟最好英雄是什么"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.0250, 0.9750]])


0.9750417470932007

In [9]:
# Synonym pair: both sentences mean "I am very happy" (different word for "happy").
text1 = "我很高兴"
text2 = "我很开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9964, 0.0036]])


0.0035700732842087746

In [10]:
# The two "I am very happy" synonyms again, with the argument order swapped.
text1 = "我很开心"
text2 = "我很高兴"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9964, 0.0036]])


0.0035700732842087746

In [11]:
# "I am very happy" vs. "I am really, really happy".
text1 = "我很高兴"
text2 = "我特别特别开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9986, 0.0014]])


0.0014140098355710506

In [12]:
# "I am very happy" vs. "I actually feel that I am quite happy".
text1 = "我很高兴"
text2 = "我其实觉得自己很开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9980, 0.0020]])


0.0019592756871134043

In [13]:
# "I am really, really happy" vs. "I actually feel that I am quite happy".
text1 = "我特别特别开心"
text2 = "我其实觉得自己很开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9565, 0.0435]])


0.043486587703228

In [14]:
# Opposite meaning, different word for "happy": "I am very happy" vs. "I am not happy".
text1 = "我很高兴"
text2 = "我不开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9992e-01, 7.5343e-05]])


7.53425556467846e-05

In [15]:
# Opposite meaning, same word for "happy": "I am very happy" vs. "I am not happy".
text1 = "我很高兴"
text2 = "我不高兴"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9537, 0.0463]])


0.046322498470544815

In [16]:
# Identical inputs: "I am very happy" against itself.
text1 = "我很高兴"
text2 = "我很高兴"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.0056, 0.9944]])


0.9944317936897278

In [17]:
# Synonym pair: both sentences mean "I am very sad".
text1 = "我很伤心"
text2 = "我很难过"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9976, 0.0024]])


0.0023956228978931904

In [18]:
# Two short positive phrases: "really good" vs. "not bad".
text1 = "真好"
text2 = "不错"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9975, 0.0025]])


0.002549131866544485

In [19]:
# Bare synonyms, no context: two words for "happy".
text1 = "高兴"
text2 = "开心"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9970e-01, 2.9801e-04]])


0.00029800814809277654

In [20]:
# Identical inputs: the slogan "it is only truly good when everyone is doing well".
text1 = "大家好才是真的好"
text2 = "大家好才是真的好"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.0035, 0.9965]])


0.9965393543243408

In [21]:
# Unrelated pair: a long question about Tencent Credit / Weilidai loan limits
# vs. "my wallet does not have this app of yours".
text1 = "为什么能开出腾讯信用却没有微粒贷朋友的没用腾讯信用却有30000的额度呢"
text2 = "我钱包里没有你们这个应用"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9949e-01, 5.1214e-04]])


0.0005121392896398902

In [22]:
# Two short conversational replies: "I don't know either" vs. "okay then".
text1 = "我也不知道"
text2 = "好吧"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9986, 0.0014]])


0.0014128254260867834

In [23]:
# Related technical terms: "deep learning" vs. "machine learning".
text1 = "深度学习"
text2 = "机器学习"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9833, 0.0167]])


0.0167356226593256

In [24]:
# "machine learning" vs. "deep learning": the same two terms with the argument order swapped.
text1 = "机器学习"
text2 = "深度学习"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9833, 0.0167]])


0.0167356226593256

In [25]:
# "What is a people's organization" vs. "a people's organization refers to".
text1 = "人民团体是什么"
text2 = "人民团体是指"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.0541, 0.9459]])


0.9458606243133545

# 百度-车牌，不太一致

In [26]:
# "How to mount the license plate on the front of the car" vs. "how to install the front plate".
text1 = "车头如何放置车牌"
text2 = "前牌照怎么装"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9849, 0.0151]])


0.015069725923240185

In [27]:
# "How to mount the license plate on the front of the car" vs. "how to apply for a Beijing license plate".
text1 = "车头如何放置车牌"
text2 = "如何办理北京车牌"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9971, 0.0029]])


0.002867944072932005

In [28]:
# "How to mount the license plate on the front of the car" vs. "how to install the rear plate".
text1 = "车头如何放置车牌"
text2 = "后牌照怎么装"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9921, 0.0079]])


0.00785581860691309

# 百度-信号

In [29]:
# "The signal keeps going strong and weak" vs. "the signal keeps going high and low".
text1 = "信号忽强忽弱"
text2 = "信号忽高忽低"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.2675, 0.7325]])


0.7324687242507935

In [30]:
# "The signal keeps going strong and weak" vs. "the signal keeps drifting left and right".
text1 = "信号忽强忽弱"
text2 = "信号忽左忽右"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.5458, 0.4542]])


0.4542343318462372

In [31]:
# "The signal keeps going strong and weak" vs. "the signal suddenly cut out".
text1 = "信号忽强忽弱"
text2 = "信号忽然中断"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.7430, 0.2570]])


0.2570425868034363

# 百度-机器学习

In [32]:
# Shared characters, unrelated meaning: "how to learn deep learning well"
# vs. "study Xi Jinping's speech materials in depth".
text1 = "如何学好深度学习"
text2 = "深入学习习近平讲话材料"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9981, 0.0019]])


0.001922023482620716

In [33]:
# "How to learn deep learning well" vs. "machine learning tutorial".
text1 = "如何学好深度学习"
text2 = "机器学习教程"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9923e-01, 7.6823e-04]])


0.000768233323469758

In [34]:
# "How to learn deep learning well" vs. "artificial intelligence tutorial".
text1 = "如何学好深度学习"
text2 = "人工智能教程"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9997e-01, 2.9498e-05]])


2.949829831777606e-05

# 百度-香蕉的翻译，偏小但排序一致

In [35]:
# "Translation of 'banana'" vs. "how do you say banana in English".
text1 = "香蕉的翻译"
text2 = "香蕉用英文怎么说"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9821, 0.0179]])


0.01794440858066082

In [36]:
# "Translation of 'banana'" vs. "how to eat a banana".
text1 = "香蕉的翻译"
text2 = "香蕉怎么吃"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9744, 0.0256]])


0.025603730231523514

In [37]:
# "Translation of 'banana'" vs. "how do you say tangerine in English".
text1 = "香蕉的翻译"
text2 = "桔子用英文怎么说"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9977e-01, 2.3198e-04]])


0.00023197510745376348

# 百度-腹泻，排序有差别

In [38]:
# "Folk remedy for infant diarrhea" vs. the colloquial "folk remedy for a baby's upset stomach".
text1 = "小儿腹泻偏方"
text2 = "宝宝拉肚子偏方"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9985, 0.0015]])


0.0015421390999108553

In [39]:
# "Folk remedy for infant diarrhea" vs. "folk remedy for a child's cold".
text1 = "小儿腹泻偏方"
text2 = "小儿感冒偏方"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9914e-01, 8.5903e-04]])


0.0008590331999585032

In [40]:
# "Folk remedy for infant diarrhea" vs. "folk remedy for diarrhea".
text1 = "小儿腹泻偏方"
text2 = "腹泻偏方"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.0895, 0.9105]])


0.9104790687561035

# 百度-LOL，数值偏小，但排序一致

In [41]:
# "Is League of Legends fun, and how do I level up" vs. "League of Legends strategy guide".
text1 = "英雄联盟好玩吗，怎么升级"
text2 = "英雄联盟攻略"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9855, 0.0145]])


0.014453393407166004

In [42]:
# "Is League of Legends fun, and how do I level up" vs. "League of Legends server upgrade".
text1 = "英雄联盟好玩吗，怎么升级"
text2 = "英雄联盟服务器升级"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9672, 0.0328]])


0.032777220010757446

In [43]:
# "Is League of Legends fun, and how do I level up" vs. "how to play League of Legends well".
text1 = "英雄联盟好玩吗，怎么升级"
text2 = "怎么打好英雄联盟"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.7560, 0.2440]])


0.24402819573879242

# 百度-红米

In [44]:
# "Redmi update error" vs. "Redmi system upgrade error".
text1 = "红米更新出错"
text2 = "红米升级系统出错"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9206, 0.0794]])


0.07939153909683228

In [45]:
# "Redmi update error" vs. "Redmi account error".
text1 = "红米更新出错"
text2 = "红米账户出错"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.7932, 0.2068]])


0.206807941198349

In [46]:
# "Redmi update error" vs. "how to buy a Xiaomi phone".
text1 = "红米更新出错"
text2 = "如何买到小米手机"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9982e-01, 1.8473e-04]])


0.00018473295494914055

# 百度-李彦宏

In [47]:
# "Robin Li is the founder of Baidu" vs. "Baidu was founded by Robin Li".
text1 = "李彦宏是百度公司创始人"
text2 = "百度是李彦宏创办的"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.4403, 0.5597]])


0.5596696734428406

In [48]:
# "Robin Li is the founder of Baidu" vs. "Pony Ma founded Tencent".
text1 = "李彦宏是百度公司创始人"
text2 = "马化腾创办了腾讯公司"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9960, 0.0040]])


0.0040373411029577255

In [49]:
# "Robin Li is the founder of Baidu" vs. "Yao Ming is a famous NBA star".
text1 = "李彦宏是百度公司创始人"
text2 = "姚明是NBA的著名球星"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9981e-01, 1.9472e-04]])


0.00019472363055683672

# 百度-中国历史

In [50]:
# "China has five thousand years of history" vs. "China is a country with a long history".
text1 = "中国有五千年的历史"
text2 = "中国是个历史悠久的国家"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9994e-01, 5.5082e-05]])


5.508239337359555e-05

In [51]:
# "China has five thousand years of history" vs. "China has many ethnic minorities".
text1 = "中国有五千年的历史"
text2 = "中国有很多少数民族"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9995e-01, 4.6088e-05]])


4.608769086189568e-05

In [52]:
# "China has five thousand years of history" vs. "China has a population of 1.3 billion".
text1 = "中国有五千年的历史"
text2 = "中国有13亿人口"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9999e-01, 5.4555e-06]])


5.455455720948521e-06

# 百度-北京奥运会，偏小，但数值一致

In [53]:
# "Beijing successfully bid for the 2008 Olympics" vs. "the 2008 Olympics were held in Beijing".
text1 = "北京成功申办了2008年奥运会"
text2 = "2008年奥运会在北京举行"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[0.9911, 0.0089]])


0.00885780993849039

In [54]:
# "Beijing successfully bid for the 2008 Olympics" vs. "the London Olympics were held in 2012".
text1 = "北京成功申办了2008年奥运会"
text2 = "伦敦奥运会在2012年举行"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9995e-01, 4.8986e-05]])


4.898626502836123e-05

In [55]:
# "Beijing successfully bid for the 2008 Olympics" vs. "the Tokyo Olympics are about to be held".
text1 = "北京成功申办了2008年奥运会"
text2 = "东京奥运会即将举办"
pos_prob, vec1, vec2 = sim(text1, text2)
pos_prob

tensor([[9.9968e-01, 3.1851e-04]])


0.000318505015457049