In [13]:
import csv
import os
import pickle
import string
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# Raw SNLI 1.0 splits (tab-separated .txt release, parsed by read_data)
data_train_dir = "snli_1.0_train.txt"
data_dev_dir = "snli_1.0_dev.txt"
data_test_dir = "snli_1.0_test.txt"

# Pre-trained word vectors, and the vocabulary file written by build_worddict
embedding_file_dir = "Data/glove.6B.50d.txt"
worddict_dir = 'worddict.txt'

In [14]:
def read_data(data_dir):
    """Read one SNLI ``.txt`` split into parallel lists of sentences and labels.

    The first line (the column header) is skipped. Each remaining line is
    stripped and split on tabs; column 0 is the gold label and columns 5 / 6
    are the premise / hypothesis sentences. Rows whose gold label is not one
    of ``entailment`` / ``neutral`` / ``contradiction`` are dropped. In the
    kept sentences every ``string.punctuation`` character is replaced by a
    space and the text is lower-cased (no whitespace normalisation is done).

    Args:
        data_dir: path to the UTF-8 encoded split file.

    Returns:
        dict with keys ``"premise"`` and ``"hypothesis"`` (lists of cleaned
        strings) and ``"labels"`` (list of the label strings).
    """
    valid_labels = ("entailment", "neutral", "contradiction")
    no_punct = str.maketrans({ch: " " for ch in string.punctuation})

    def clean(text):
        return text.translate(no_punct).lower()

    result = {"premise": [], "hypothesis": [], "labels": []}
    with open(data_dir, 'r', encoding='utf-8') as handle:
        next(handle)  # skip the header row
        for raw in handle:
            fields = raw.strip().split('\t')
            gold = fields[0]
            if gold not in valid_labels:
                # example without a usable gold label
                continue
            result["premise"].append(clean(fields[5]))
            result["hypothesis"].append(clean(fields[6]))
            result["labels"].append(gold)
    return result
def build_worddict(data, save_path=None):
    """Build the vocabulary from all premises and hypotheses and save it to disk.

    Ids are assigned in first-occurrence order starting at 0. The four special
    tokens come first (``_PAD_``=0, ``_OOV_``=1, ``_BOS_``=2, ``_EOS_``=3),
    followed by the words of every premise and then of every hypothesis.
    Sentences are tokenised with ``str.split()`` (any whitespace run is one
    separator), so the doubled spaces left by punctuation removal in
    ``read_data`` never produce empty-string "words".

    Args:
        data: dict with lists of sentences under ``"premise"`` and ``"hypothesis"``.
        save_path: file the vocabulary is written to, one ``word<TAB>id`` line
            per entry. Defaults to the notebook-level ``worddict_dir``.

    Returns:
        (word_id, id_word): mappings word -> id and id -> word.
    """
    word_id = {}
    id_word = {}

    def add(word):
        # next free id == current vocabulary size, since ids are contiguous
        if word not in word_id:
            idx = len(word_id)
            word_id[word] = idx
            id_word[idx] = word

    for special in ("_PAD_", "_OOV_", "_BOS_", "_EOS_"):
        add(special)
    for field in ("premise", "hypothesis"):
        for sentence in data[field]:
            for word in sentence.split():
                add(word)

    # Save the vocabulary
    path = worddict_dir if save_path is None else save_path
    with open(path, "w", encoding='utf-8') as f:
        for word, idx in word_id.items():
            f.write("%s\t%d\n" % (word, idx))
    return word_id, id_word

def sentence2idList(sentence, word_id):
    """Convert a cleaned sentence to a list of vocabulary ids.

    The result is ``[_BOS_, w1, ..., wn, _EOS_]``; words missing from
    ``word_id`` map to the ``_OOV_`` id. Tokenisation uses ``str.split()``, so
    runs of whitespace (e.g. left by punctuation removal) do not produce empty
    tokens.

    Args:
        sentence: the sentence text.
        word_id: vocabulary mapping word -> id; must contain ``_BOS_``,
            ``_EOS_`` and ``_OOV_``.

    Returns:
        list of int ids.
    """
    bos = word_id["_BOS_"]
    oov = word_id["_OOV_"]
    body = [word_id.get(word, oov) for word in sentence.split()]
    return [bos, *body, word_id["_EOS_"]]

def data2id(data, word_id):
    """Convert sentences to id lists and label strings to class ids.

    Args:
        data: dict as returned by ``read_data``: parallel lists under
            ``"premise"``, ``"hypothesis"`` and ``"labels"``.
        word_id: vocabulary mapping word -> id, passed to ``sentence2idList``.

    Returns:
        dict with ``"premise_id"`` / ``"hypothesis_id"`` (lists of id lists
        from ``sentence2idList``) and ``"labels_id"`` (0 = entailment,
        1 = neutral, 2 = contradiction). Examples whose label is not one of
        those three are skipped.
    """
    label_to_class = {"entailment": 0, "neutral": 1, "contradiction": 2}
    converted = {"premise_id": [], "hypothesis_id": [], "labels_id": []}
    labels = data["labels"]
    for pos in range(len(labels)):
        label = labels[pos]
        if label not in label_to_class:
            # label outside the three classes: nothing to learn from
            continue
        converted["premise_id"].append(sentence2idList(data["premise"][pos], word_id))
        converted["hypothesis_id"].append(sentence2idList(data["hypothesis"][pos], word_id))
        converted["labels_id"].append(label_to_class[label])
    return converted

def build_embeddings(embedding_file, word_id):
    """Build an embedding matrix whose row ``word_id[w]`` holds the vector of ``w``.

    ``embedding_file`` is a whitespace-separated text file with one
    ``word v1 v2 ... vd`` entry per line (GloVe text format). Only words present
    in ``word_id`` are kept. Rows for vocabulary words absent from the file are
    drawn from a standard normal distribution, except ``_PAD_``, which stays
    all-zero. The number of such randomly initialised words is printed.

    Args:
        embedding_file: path to the UTF-8 vector file.
        word_id: vocabulary mapping word -> integer id.

    Returns:
        float64 ``np.ndarray`` of shape ``(max(id) + 1, d)``; for a contiguous
        vocabulary such as the one from ``build_worddict`` this is
        ``(len(word_id), d)``.

    Raises:
        ValueError: if no vocabulary word occurs in ``embedding_file`` (the
            embedding dimension cannot be determined).
    """
    # Read the vectors of the vocabulary words
    vectors = {}
    with open(embedding_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split()
            if not parts:
                continue
            word = parts[0]
            if word in word_id:
                # parse explicitly instead of relying on implicit str -> float casts
                vectors[word] = np.asarray(parts[1:], dtype=np.float64)
    if not vectors:
        raise ValueError("no vocabulary word found in %s" % embedding_file)

    # Fill the matrix; the dimension comes from any loaded vector (not a fixed word)
    embedding_dim = len(next(iter(vectors.values())))
    num_rows = max(word_id.values()) + 1
    embedding_matrix = np.zeros((num_rows, embedding_dim))
    missed_cnt = 0
    # Index by the word's own id, not by its position in the dict's iteration order
    for word, idx in word_id.items():
        if word in vectors:
            embedding_matrix[idx] = vectors[word]
        elif word != "_PAD_":
            missed_cnt += 1
            embedding_matrix[idx] = np.random.normal(size=embedding_dim)
    print("missed word count: %d" % (missed_cnt))
    return embedding_matrix
          
    

In [15]:
#读取数据
# Read the raw training split
data_str=read_data(data_train_dir)
# Build the vocabulary (build_worddict also writes it to worddict_dir)
word_id,id_word=build_worddict(data_str)   
# Skip unlabeled examples and convert sentences / labels to ids
data_id=data2id(data_str,word_id)

In [19]:
# Peek at the first cleaned premise (punctuation replaced by spaces, lower-cased)
data_str['premise'][0]

'a person on a horse jumps over a broken down airplane '

In [None]:
#保存 data_train_str和data_train_id
# Save the cleaned training strings (data_str) and their id form (data_id).
# The output paths are defined here so this cell runs on a fresh kernel; the id
# path must match data_train_id_dir in the training configuration cell.
data_train_str_dir = 'train_data_str.pkl'
data_train_id_dir = 'train_data_id.pkl'
with open(data_train_str_dir, "wb") as f:
    pickle.dump(data_str, f)
with open(data_train_id_dir, "wb") as f:
    pickle.dump(data_id, f)

In [None]:
# Embedding matrix aligned with the vocabulary ids (one row per id)
embedding_matrix = build_embeddings(embedding_file_dir, word_id)
print(f"embedding_matrix size: {len(embedding_matrix)}")

In [None]:
# Persist the embedding matrix. The path is defined here so this cell runs on
# a fresh kernel; it must match embedding_matrix_dir in the training config cell.
embedding_matrix_dir = 'embedding_matrix.pkl'
with open(embedding_matrix_dir, "wb") as f:
    pickle.dump(embedding_matrix, f)

In [None]:
import os
import torch
import pickle
from torch.utils.data import DataLoader
from model.SnliDataSet import SnliDataSet

# Input files produced by the preprocessing cells.
# NOTE(review): the preprocessing above only writes the *training* id file;
# confirm dev_data_id.pkl is generated elsewhere.
worddict_dir = 'worddict.txt'
data_train_id_dir = 'train_data_id.pkl'
data_dev_id_dir = 'dev_data_id.pkl'
embedding_matrix_dir = 'embedding_matrix.pkl'
# Checkpoint file prefix; os.path.join gives the right separator on every OS
# (a literal backslash would become part of the file name on POSIX).
model_train_dir = os.path.join('saved_model', 'train_model_')

# Hyperparameters
batch_size = 1
use_gpu = True
patience = 5  # epochs without dev-accuracy improvement before early stopping

# Fall back to the CPU when a GPU is requested but CUDA is not available.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

In [None]:
# Model / optimisation hyperparameters
hidden_size = 50        # hidden size passed to the ESIM constructor
dropout = 0.5           # dropout rate passed to ESIM
num_classes = 3         # entailment / neutral / contradiction
lr = 4e-4               # Adam learning rate
epochs = 2              # maximum number of training epochs
max_grad_norm = 10.0    # gradient-norm clipping threshold used by train()

In [None]:
# 加载数据
with open(data_train_id_dir,'rb') as f:
    train_data=SnliDataSet(pickle.load(f),max_premise_len=None,max_hypothesis_len=None)
train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True)

with open(data_dev_id_dir,'rb') as f:
    dev_data=SnliDataSet(pickle.load(f),max_premise_len=None,max_hypothesis_len=None)
dev_loader=DataLoader(dev_data,batch_size=batch_size,shuffle=False)

#加载embedding
with open(embedding_matrix_dir,'rb') as f:
    embeddings=torch.tensor(pickle.load(f),dtype=torch.float).to(device)

In [None]:
# Build the ESIM model and move it to the training device.
# The first two positional arguments are the embedding matrix's row count
# (one row per vocabulary id) and its column count (embedding dimension).
# NOTE(review): parameter meanings inferred from the call site — confirm
# against model/esim.py.
from model.esim import ESIM
model = ESIM(embeddings.shape[0],
             embeddings.shape[1],
             hidden_size,
             embeddings=embeddings,
             dropout=dropout,
             num_classes=num_classes,
             device=device).to(device)

In [None]:
#准备训练
# Prepare training: cross-entropy on the logits, Adam, and a plateau scheduler.
# mode="max" treats the monitored metric as higher-is-better; with patience=0 the
# learning rate is halved (factor=0.5) after any epoch in which it does not improve.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode="max",factor=0.5,patience=0)


In [None]:
def getCorrectNum(probs, targets):
    """Count the examples whose highest-scoring class equals the target.

    Args:
        probs: tensor of class scores; the arg-max is taken over dim 1
            (presumably shape (batch, num_classes)).
        targets: tensor of gold class indices, compared element-wise with
            the predicted indices.

    Returns:
        Python int number of correct predictions.
    """
    predicted = probs.max(dim=1)[1]
    hits = predicted.eq(targets)
    return int(hits.sum().item())

def train(model, data_loader, optimizer, criterion, max_gradient_norm):
    """Run one training epoch over ``data_loader``.

    For every batch: move the tensors to ``model.device``, zero the gradients,
    run the forward pass, compute ``criterion(logits, labels)``, back-propagate,
    clip the total gradient norm to ``max_gradient_norm`` and take an optimizer
    step. A progress line with the running mean batch loss is printed per batch.

    Args:
        model: module with a ``device`` attribute whose call
            ``model(premises, premises_len, hypothesis, hypothesis_len)``
            returns ``(logits, probs)``.
        data_loader: iterable of dicts with the keys ``premises``,
            ``premises_len``, ``hypothesis``, ``hypothesis_len`` and ``labels``;
            must support ``len()`` and expose ``.dataset``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits, labels)``.
        max_gradient_norm: gradient-norm clipping threshold.

    Returns:
        (epoch_time, epoch_loss, epoch_accuracy): wall time in seconds, the mean
        of the per-batch losses, and correct predictions divided by
        ``len(data_loader.dataset)``.
    """
    model.train()
    target_device = model.device
    batch_keys = ("premises", "premises_len", "hypothesis", "hypothesis_len", "labels")

    epoch_start = time.time()
    loss_sum = 0
    n_correct = 0
    n_batches = len(data_loader)

    for step, batch in enumerate(data_loader, start=1):
        batch_start = time.time()
        # move every field of the batch to the model's device, in key order
        premises, premises_len, hypothesis, hypothesis_len, labels = (
            batch[key].to(target_device) for key in batch_keys
        )

        optimizer.zero_grad()
        logits, probs = model(premises, premises_len, hypothesis, hypothesis_len)

        # loss, backward pass, gradient clipping, weight update
        loss = criterion(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
        optimizer.step()

        loss_sum += loss.item()
        n_correct += getCorrectNum(probs, labels)
        print("Training  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
              .format(step, n_batches, time.time() - batch_start, loss_sum / step))

    elapsed = time.time() - epoch_start
    return elapsed, loss_sum / n_batches, n_correct / len(data_loader.dataset)



In [None]:
def validate(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``. A progress line with the running mean batch loss is
    printed per batch.

    Args:
        model: module with a ``device`` attribute whose call
            ``model(premises, premises_len, hypothesis, hypothesis_len)``
            returns ``(logits, probs)``.
        data_loader: iterable of dicts with the keys ``premises``,
            ``premises_len``, ``hypothesis``, ``hypothesis_len`` and ``labels``;
            must support ``len()`` and expose ``.dataset``.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        (epoch_time, epoch_loss, epoch_accuracy): wall time in seconds, the mean
        of the per-batch losses, and correct predictions divided by
        ``len(data_loader.dataset)``.
    """
    model.eval()
    device = model.device

    time_epoch_start = time.time()
    running_loss = 0
    correct_cnt = 0

    # Evaluation needs no gradients: without no_grad() autograd would record a
    # graph for every batch, wasting memory and time.
    with torch.no_grad():
        for batch_cnt, batch in enumerate(data_loader, start=1):
            time_batch_start = time.time()
            # Move the batch to the model's device
            premises = batch["premises"].to(device)
            premises_len = batch["premises_len"].to(device)
            hypothesis = batch["hypothesis"].to(device)
            hypothesis_len = batch["hypothesis_len"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass and loss
            logits, probs = model(premises, premises_len, hypothesis, hypothesis_len)
            loss = criterion(logits, labels)

            running_loss += loss.item()
            correct_cnt += getCorrectNum(probs, labels)
            print("Testing  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
                  .format(batch_cnt, len(data_loader), time.time() - time_batch_start, running_loss / batch_cnt))

    epoch_time = time.time() - time_epoch_start
    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = correct_cnt / len(data_loader.dataset)
    return epoch_time, epoch_loss, epoch_accuracy



In [None]:
#训练过程中的参数
# Training-loop bookkeeping
best_score = 0.0
train_losses = []
valid_losses = []
patience_cnt = 0

# torch.save does not create missing directories, so create the checkpoint
# directory (if model_train_dir contains one) before the first save.
checkpoint_dir = os.path.dirname(model_train_dir)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    # Train
    print("-"*50, "Training epoch %d" % (epoch), "-"*50)
    epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, criterion, max_grad_norm)
    train_losses.append(epoch_loss)
    print("Training time: {:.4f}s, loss :{:.4f}, accuracy: {:.4f}%".format(epoch_time, epoch_loss, (epoch_accuracy*100)))

    # Validate
    print("-"*50, "Validating epoch %d" % (epoch), "-"*50)
    epoch_time_dev, epoch_loss_dev, epoch_accuracy_dev = validate(model, dev_loader, criterion)
    valid_losses.append(epoch_loss_dev)
    print("Validating time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n".format(epoch_time_dev, epoch_loss_dev, (epoch_accuracy_dev*100)))

    # Update the learning rate from the *validation* accuracy. Training accuracy
    # keeps rising while the model overfits, so it is the wrong signal for a
    # plateau scheduler (mode="max").
    scheduler.step(epoch_accuracy_dev)

    # Early stopping on validation accuracy
    if epoch_accuracy_dev < best_score:
        patience_cnt += 1
    else:
        best_score = epoch_accuracy_dev
        patience_cnt = 0
    if patience_cnt >= patience:
        print("-"*50, "Early stopping", "-"*50)
        break

    # Save a checkpoint after every completed epoch
    torch.save({"epoch": epoch,
                "model": model.state_dict(),
                "best_score": best_score,
                "train_losses": train_losses,
                "valid_losses": valid_losses},
               model_train_dir + str(epoch) + ".dir")