In [1]:
from data_load import load_char_data, load_vocab
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset

## DSSM模型搭建

In [2]:
class DSSM(nn.Module):
    def __init__(self, CHAR_SIZE, embedding_size):
        super(DSSM, self).__init__()
        self.embedding=nn.Embedding(CHAR_SIZE,embedding_size)
        self.linear1=nn.Linear(embedding_size,256)
        self.linear2=nn.Linear(256,128)
        self.linear3=nn.Linear(128,64)
        self.dropout=nn.Dropout(p=0.2)
    
    def forward(self, a, b):
        a=self.embedding(a).sum(1)
        b=self.embedding(b).sum(1)
        
        a=torch.tanh(self.linear1(a))
        a=self.dropout(a)
        a=torch.tanh(self.linear2(a))
        a=self.dropout(a)
        a=torch.tanh(self.linear3(a))
        a=self.dropout(a)
        
        b=torch.tanh(self.linear1(b))
        b=self.dropout(b)
        b=torch.tanh(self.linear2(b))
        b=self.dropout(b)
        b=torch.tanh(self.linear3(b))
        b=self.dropout(b)
        
        cosine=torch.cosine_similarity(a,b,dim=1,eps=1e-8)
        return cosine
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

模型的一些参数，可根据实际需求增添，调整大小

In [3]:
EPOCH=50
BATCH_SIZE=50
LR=0.0005

In [7]:
CHAR_SIZE=10041
embedding_size=300

## 需定义自己的数据集，模板格式如下：

In [11]:
class MRPCDataset(Dataset):
    def __init__(self,filepath):
        self.path=filepath
        self.a_index, self.b_index, self.label=load_char_data(filepath)
        
    def __len__(self):
        return len(self.a_index)
    
    def __getitem__(self,idx):
        return self.a_index[idx],self.b_index[idx],self.label[idx]

In [12]:
dev_path='./MRPC/test_data.csv'
test4test=MRPCDataset(dev_path)
test4test[10]

(array([ 149, 9801, 4608, 8518, 7378, 2775, 5882, 7446, 5849, 7148,  121,
        7250, 8574, 4730, 1630, 5825, 7447, 2430, 4516, 4434, 7995,  255,
        4320,  434, 1937,  585, 5472, 8103, 8017, 6666, 1252, 3661, 7471,
        3210, 2680, 2135, 9330, 4277, 4307, 5804, 6718, 1370, 9089, 4350,
        5752, 3658, 6683, 3465, 7819,   92, 7714, 7716, 1573, 8165, 7784,
        6362, 5046, 7279, 2817,  702, 2729,  959,   92, 8011, 5479, 4369,
        2159,  672, 2963, 2353]),
 array([ 434, 1937, 4541, 4939,   45, 8282, 3157, 2776, 2738, 5279, 2342,
        1215, 6617, 8283, 9138, 3932,  705, 4414, 3547, 8632,  149, 9801,
        4608, 8518, 7378, 2775, 5882, 7446, 5849, 7148,  121, 7250, 8574,
        4730, 4839,  433,  325, 2695, 3785,  980, 4839, 8373, 3852, 1937,
        7279, 2817, 8909, 1370, 6723, 8103, 8017, 1370, 9089, 4350, 5752,
        3658, 6683, 3465, 7819,   92, 7714, 7716, 1573, 8367, 6000, 5281,
        9337, 5588, 1894, 6930]),
 1)

In [13]:
data_root='./MRPC/'
train_path=data_root+'train_data.csv'
test_path=data_root+'test_data.csv'

In [14]:
# 创建数据集并创立数据载入器
train_data=MRPCDataset(train_path)
test_data=MRPCDataset(test_path)
tarin_loader=DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader=DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
dssm=DSSM(CHAR_SIZE=CHAR_SIZE, embedding_size=embedding_size).to(device)

optimizer=torch.optim.Adam(dssm.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss()

In [19]:
for epoch in range(EPOCH):
    for step, (text_a, text_b, label) in enumerate(tarin_loader):
        a = text_a.to(device).long()
        b = text_b.to(device).long()
        l = torch.LongTensor(label).to(device)
        
        pos_res = dssm(a,b)
        neg_res = l-pos_res
        
        out = torch.stack([neg_res, pos_res],1).to(device)
        loss = loss_func(out,l)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if(step+1) % 20 == 0:
            total=0
            correct=0
            for (test_a, test_b, test_l) in test_loader:
                tst_a = test_a.to(device).long()
                tst_b = test_b.to(device).long()
                tst_l = torch.LongTensor(test_l).to(device)
                
                pos_res = dssm(tst_a, tst_b)
                neg_res = 1-pos_res
                out=torch.max(torch.stack([neg_res, pos_res],1).to(device), 1)[1]
                if out.size()==tst_l.size():
                    total+=tst_l.size(0)
                    correct+=(out==tst_l).sum().item()
            print('[Epoch]:',epoch+1,'训练loss',loss.item())
            print('[Epoch]:',epoch+1,'测试集准确率',(correct*1.0/total))

[Epoch]: 1 训练loss 0.8622357249259949
[Epoch]: 1 测试集准确率 0.6585507246376812
[Epoch]: 1 训练loss 0.806443989276886
[Epoch]: 1 测试集准确率 0.6707246376811594
[Epoch]: 1 训练loss 0.8646817207336426
[Epoch]: 1 测试集准确率 0.6742028985507247
[Epoch]: 1 训练loss 0.8290250897407532
[Epoch]: 1 测试集准确率 0.6759420289855073
[Epoch]: 2 训练loss 0.795433521270752
[Epoch]: 2 测试集准确率 0.6840579710144927
[Epoch]: 2 训练loss 0.7097998857498169
[Epoch]: 2 测试集准确率 0.6747826086956522
[Epoch]: 2 训练loss 0.7851957678794861
[Epoch]: 2 测试集准确率 0.663768115942029
[Epoch]: 2 训练loss 0.8129715919494629
[Epoch]: 2 测试集准确率 0.6556521739130434
[Epoch]: 3 训练loss 0.6672949194908142
[Epoch]: 3 测试集准确率 0.6423188405797101
[Epoch]: 3 训练loss 0.746493399143219
[Epoch]: 3 测试集准确率 0.6678260869565218
[Epoch]: 3 训练loss 0.7733408212661743
[Epoch]: 3 测试集准确率 0.6614492753623188
[Epoch]: 3 训练loss 0.8352243900299072
[Epoch]: 3 测试集准确率 0.673623188405797
[Epoch]: 4 训练loss 0.8113890886306763
[Epoch]: 4 测试集准确率 0.64
[Epoch]: 4 训练loss 0.7458502054214478
[Epoch]: 4 测试集准确率 0.

[Epoch]: 28 训练loss 0.375754714012146
[Epoch]: 28 测试集准确率 0.6295652173913043
[Epoch]: 28 训练loss 0.4540126323699951
[Epoch]: 28 测试集准确率 0.6185507246376811
[Epoch]: 29 训练loss 0.46712660789489746
[Epoch]: 29 测试集准确率 0.6452173913043479
[Epoch]: 29 训练loss 0.4109412133693695
[Epoch]: 29 测试集准确率 0.6324637681159421
[Epoch]: 29 训练loss 0.42340928316116333
[Epoch]: 29 测试集准确率 0.6411594202898551
[Epoch]: 29 训练loss 0.4350065290927887
[Epoch]: 29 测试集准确率 0.6365217391304347
[Epoch]: 30 训练loss 0.48754069209098816
[Epoch]: 30 测试集准确率 0.6365217391304347
[Epoch]: 30 训练loss 0.3937561511993408
[Epoch]: 30 测试集准确率 0.647536231884058
[Epoch]: 30 训练loss 0.4726126492023468
[Epoch]: 30 测试集准确率 0.6307246376811594
[Epoch]: 30 训练loss 0.42946183681488037
[Epoch]: 30 测试集准确率 0.6307246376811594
[Epoch]: 31 训练loss 0.5526639223098755
[Epoch]: 31 测试集准确率 0.624927536231884
[Epoch]: 31 训练loss 0.48910513520240784
[Epoch]: 31 测试集准确率 0.6330434782608696
[Epoch]: 31 训练loss 0.39372554421424866
[Epoch]: 31 测试集准确率 0.6394202898550725
[Epoch]: 

In [20]:
torch.save(dssm,'./dssm.pkl')

In [21]:
if __name__ == '__main__':
    test_data=MRPCDataset(test_path)
    test_loader=DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dssm = torch.load('./dssm.pkl').to(device)
    
    total=0
    correct=0
    TP, TN, FP, FN = 0, 0, 0, 0
    FLAG=True
    for (test_a, test_b, test_l) in test_loader:
        tst_a = test_a.to(device).long()
        tst_b = test_b.to(device).long()
        tst_l = torch.LongTensor(test_l).to(device)

        pos_res = dssm(tst_a, tst_b)
        neg_res = 1 - pos_res
        out = torch.max(torch.stack([neg_res, pos_res], 1).to(device), dim=1)[1]

        total += tst_l.size(0)
        correct += (out == tst_l).sum().item()

        #计算精确率、召回率
        TP += ((out == 1) & (tst_l == 1)).sum().item()
        TN += ((out == 0) & (tst_l == 0)).sum().item()
        FN += ((out == 0) & (tst_l == 1)).sum().item()
        FP += ((out == 1) & (tst_l == 0)).sum().item()

        if FLAG == True:
            for i in range(30,40):
                a, b, l = test_data[i][0], test_data[i][1], test_data[i][2]
                print('标签：',l,'预测：',out[i].item())
        FLAG=False

    p = TP / (TP + FP)
    r = TP / (TP + FN)

    print('测试集准确率: ', (correct * 1.0 / total))
    print('测试集精确率：', p)
    print('测试集召回率：', r)
    print('测试集f1-score：', 2 * r * p / (r + p))

标签： 1 预测： 1
标签： 0 预测： 1
标签： 0 预测： 1
标签： 0 预测： 1
标签： 1 预测： 1
标签： 0 预测： 1
标签： 1 预测： 0
标签： 1 预测： 0
标签： 0 预测： 1
标签： 1 预测： 1
测试集准确率:  0.6208695652173913
测试集精确率： 0.7126833477135461
测试集召回率： 0.7201394943330427
测试集f1-score： 0.7163920208152647
