BiLSTM序列标注---盲汉翻译

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random # for dataset shuffling
import openpyxl # for recording experimental value
import copy # for data shuffling

https://blog.csdn.net/vivian_ll/article/details/93894151

In [17]:
class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # LSTM以word_embeddings作为输入, 输出维度为 hidden_dim 的隐藏状态值
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # 线性层将隐藏状态空间映射到标注空间
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # 各个维度的含义是 (num_layers*num_directions, batch_size, hidden_dim)
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

In [18]:
#Making training data
original_data = []
t = open('./对应拼音标调.txt',"r", encoding='UTF8')
f = open("./训练集.txt", "rb")
#计数 数一共有多少行
lines = f.readlines()
tags = t.readlines()

# Training data: 6692(total number of sentences)*2(line&tag)
for index, line in enumerate(lines):
    line = lines[index].rstrip().split()
    tag = tags[index].rstrip().split()
    original_data.append((line, tag))

In [19]:
word_to_ix = {}
for sent, tags in original_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"0":0, "1":1, "2":2, "3":3, "4":4, ",":5, ".":6, "?":7,"!":8}

np.save('word_to_ix.npy',word_to_ix)
np.save('tag_to_ix.npy',tag_to_ix)

In [20]:
# GPU training
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
    print("number of device",torch.cuda.device_count())
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [21]:
# Training the LSTM

#Monitor trianing results
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

#Helper function to convert data into tensor
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

# Accuracy recording in excel
workbook = openpyxl.Workbook()
worksheet = workbook.active



#Change parameters to reach better performances
for i in range(10):
    # record results.
    acc_repeat_train = ['acc_train']
    acc_repeat_val = ["acc_val"]
    loss_repeat_train = ["loss_repeat_train"]
    loss_repeat_val = ["loss_repeat_val"]

    # Repeat for three times in new model with different data. 
    for repeat in range(3):
        # Initialize the training model
        EMBEDDING_DIM = 60 
        HIDDEN_DIM = 10 + 20*i
        EPOCH=30
        model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
        model = model.to(device) # For GPU calculation
        loss_function = nn.NLLLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        #Set data writer
        g = "./log/"+"Hidden_dim_"+str(EMBEDDING_DIM)+"_Embedding_dim_"+str(HIDDEN_DIM)+"repeat_"+str(repeat) 
        writer = SummaryWriter(g)

        #Shuffle data
        random.seed(repeat)
        training_data = copy.deepcopy(original_data) #Deep copy to make sure the original data is not influenced.
        random.shuffle(training_data)

        #Record accuracy


        for epoch in range(EPOCH):  
            print(epoch)
            training_acuracy = []
            validation_acuracy = []
            loss_list_train = []
            loss_list_validation = []
            for sentence, tags in tqdm(training_data[:5000]):
                # Clear Gradient because PyTorch accumulate it
                model.zero_grad()
                # Clear hidden state for LSTM, seperate them from the last instance.
                model.hidden = model.init_hidden()

                # 准备网络输入, 将其变为词索引的 Tensor 类型数据            
                sentence_in = prepare_sequence(sentence, word_to_ix)
                targets = prepare_sequence(tags, tag_to_ix)
                
                # Forward propagation
                tag_scores = model(sentence_in)

                # 第四步: 计算损失和梯度值, 通过调用 optimizer.step() 来更新梯度
                loss = loss_function(tag_scores, targets)
                a = (torch.argmax(tag_scores, dim=-1) == targets).sum().item() / len(targets)
                training_acuracy.append(a)
                loss.backward()
                optimizer.step()
                loss_list_train.append(loss.item())
            
            #Validation
            for sentence, tags in tqdm(training_data[5000:6000]):
                model.zero_grad()
                model.hidden = model.init_hidden()
                # 准备网络输入, 将其变为词索引的 Tensor 类型数据
                sentence_in = prepare_sequence(sentence, word_to_ix)
                targets = prepare_sequence(tags, tag_to_ix)
                # 第三步: 前向传播.
                tag_scores = model(sentence_in)
                #计算损失
                a = (torch.argmax(tag_scores, dim=-1) == targets).sum().item() / len(targets)
                validation_acuracy.append(a)
                loss = loss_function(tag_scores, targets)
                loss_list_validation.append(loss.item())
            
                
            #LSTM 可以做padding
            
            # Calculate error and accuracy
            # print(loss)
            loss_train= sum(loss_list_train)
            writer.add_scalar('Loss/training', loss_train, epoch)
            loss_val = sum(loss_list_validation)
            writer.add_scalar('Loss/validation', loss_val, epoch)
            #print(torch.argmax(tag_scores, dim=-1))
            # print("training_acuracy",sum(training_acuracy)/len(training_acuracy))
            acc_train = sum(training_acuracy)/len(training_acuracy)
            writer.add_scalar('Accuracy/training', acc_train, epoch)
            # print("validation_acuracy:",sum(validation_acuracy)/len(validation_acuracy))
            acc_val =sum(validation_acuracy)/len(validation_acuracy) 
            writer.add_scalar('Accuracy/validation', acc_val, epoch)

        acc_repeat_train.append(acc_train)
        acc_repeat_val.append(acc_val)
        loss_repeat_train.append(loss_train)
        loss_repeat_val.append(loss_val)         
            
        # 查看训练后的得分
        # with torch.no_grad():
        #     inputs = prepare_sequence(training_data[0][0], word_to_ix)
        #     tag_scores = model(inputs)

    worksheet.append(acc_repeat_train+acc_repeat_val+loss_repeat_train+loss_repeat_val)
writer.close()
workbook.save("Result.xlsx")

0


100%|██████████| 5000/5000 [00:28<00:00, 177.98it/s]
100%|██████████| 1000/1000 [00:01<00:00, 518.10it/s]


1


100%|██████████| 5000/5000 [00:46<00:00, 106.45it/s]
100%|██████████| 1000/1000 [00:03<00:00, 285.71it/s]


2


100%|██████████| 5000/5000 [00:52<00:00, 94.37it/s] 
100%|██████████| 1000/1000 [00:05<00:00, 170.42it/s]


3


100%|██████████| 5000/5000 [00:56<00:00, 87.99it/s] 
100%|██████████| 1000/1000 [00:03<00:00, 318.87it/s]


4


100%|██████████| 5000/5000 [00:57<00:00, 86.81it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 235.57it/s]


5


100%|██████████| 5000/5000 [00:55<00:00, 90.32it/s] 
100%|██████████| 1000/1000 [00:03<00:00, 259.05it/s]


6


100%|██████████| 5000/5000 [00:58<00:00, 85.47it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 229.55it/s]


7


100%|██████████| 5000/5000 [00:56<00:00, 87.82it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 221.46it/s]


8


100%|██████████| 5000/5000 [00:57<00:00, 86.48it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 228.94it/s]


9


100%|██████████| 5000/5000 [00:51<00:00, 96.81it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 245.96it/s]


10


100%|██████████| 5000/5000 [00:57<00:00, 87.19it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 207.87it/s]


11


100%|██████████| 5000/5000 [00:58<00:00, 84.98it/s] 
100%|██████████| 1000/1000 [00:04<00:00, 228.58it/s]


12


100%|██████████| 5000/5000 [00:32<00:00, 151.82it/s]
100%|██████████| 1000/1000 [00:01<00:00, 505.49it/s]


13


100%|██████████| 5000/5000 [00:26<00:00, 189.01it/s]
100%|██████████| 1000/1000 [00:02<00:00, 372.46it/s]


14


100%|██████████| 5000/5000 [00:29<00:00, 168.98it/s]
100%|██████████| 1000/1000 [00:03<00:00, 324.98it/s]


15


100%|██████████| 5000/5000 [00:37<00:00, 132.78it/s]
100%|██████████| 1000/1000 [00:01<00:00, 554.24it/s]


16


100%|██████████| 5000/5000 [00:24<00:00, 205.57it/s]
100%|██████████| 1000/1000 [00:01<00:00, 554.42it/s]


17


100%|██████████| 5000/5000 [00:50<00:00, 99.71it/s] 
100%|██████████| 1000/1000 [00:01<00:00, 631.16it/s]


18


100%|██████████| 5000/5000 [00:21<00:00, 230.87it/s]
100%|██████████| 1000/1000 [00:01<00:00, 544.67it/s]


19


100%|██████████| 5000/5000 [00:23<00:00, 217.18it/s]
100%|██████████| 1000/1000 [00:01<00:00, 574.08it/s]


20


100%|██████████| 5000/5000 [00:24<00:00, 202.39it/s]
100%|██████████| 1000/1000 [00:02<00:00, 472.91it/s]


21


100%|██████████| 5000/5000 [00:27<00:00, 180.00it/s]
100%|██████████| 1000/1000 [00:02<00:00, 442.94it/s]


22


100%|██████████| 5000/5000 [00:28<00:00, 175.09it/s]
100%|██████████| 1000/1000 [00:02<00:00, 451.96it/s]


23


100%|██████████| 5000/5000 [00:28<00:00, 176.14it/s]
100%|██████████| 1000/1000 [00:02<00:00, 458.74it/s]


24


100%|██████████| 5000/5000 [00:28<00:00, 174.21it/s]
100%|██████████| 1000/1000 [00:02<00:00, 446.28it/s]


25


100%|██████████| 5000/5000 [00:29<00:00, 170.69it/s]
100%|██████████| 1000/1000 [00:02<00:00, 444.20it/s]


26


100%|██████████| 5000/5000 [00:30<00:00, 165.16it/s]
100%|██████████| 1000/1000 [00:02<00:00, 383.34it/s]


27


100%|██████████| 5000/5000 [00:31<00:00, 160.86it/s]
100%|██████████| 1000/1000 [00:02<00:00, 443.51it/s]


28


100%|██████████| 5000/5000 [00:29<00:00, 167.14it/s]
100%|██████████| 1000/1000 [00:02<00:00, 417.78it/s]


29


100%|██████████| 5000/5000 [00:28<00:00, 172.81it/s]
100%|██████████| 1000/1000 [00:02<00:00, 439.98it/s]


0


100%|██████████| 5000/5000 [00:28<00:00, 174.66it/s]
100%|██████████| 1000/1000 [00:02<00:00, 443.60it/s]


1


100%|██████████| 5000/5000 [00:28<00:00, 175.74it/s]
100%|██████████| 1000/1000 [00:02<00:00, 429.54it/s]


2


100%|██████████| 5000/5000 [00:28<00:00, 176.22it/s]
100%|██████████| 1000/1000 [00:02<00:00, 435.24it/s]


3


100%|██████████| 5000/5000 [00:28<00:00, 176.61it/s]
100%|██████████| 1000/1000 [00:02<00:00, 458.80it/s]


4


100%|██████████| 5000/5000 [00:28<00:00, 175.78it/s]
100%|██████████| 1000/1000 [00:02<00:00, 444.18it/s]


5


100%|██████████| 5000/5000 [00:28<00:00, 175.69it/s]
100%|██████████| 1000/1000 [00:02<00:00, 448.54it/s]


6


100%|██████████| 5000/5000 [00:28<00:00, 177.18it/s]
100%|██████████| 1000/1000 [00:02<00:00, 466.53it/s]


7


100%|██████████| 5000/5000 [00:28<00:00, 176.58it/s]
100%|██████████| 1000/1000 [00:02<00:00, 452.08it/s]


8


100%|██████████| 5000/5000 [00:28<00:00, 176.03it/s]
100%|██████████| 1000/1000 [00:02<00:00, 435.11it/s]


9


100%|██████████| 5000/5000 [00:28<00:00, 177.01it/s]
100%|██████████| 1000/1000 [00:02<00:00, 436.50it/s]


10


100%|██████████| 5000/5000 [00:28<00:00, 177.60it/s]
100%|██████████| 1000/1000 [00:02<00:00, 473.22it/s]


11


100%|██████████| 5000/5000 [00:28<00:00, 176.69it/s]
100%|██████████| 1000/1000 [00:02<00:00, 437.40it/s]


12


100%|██████████| 5000/5000 [00:28<00:00, 177.44it/s]
100%|██████████| 1000/1000 [00:02<00:00, 441.01it/s]


13


100%|██████████| 5000/5000 [00:28<00:00, 177.47it/s]
100%|██████████| 1000/1000 [00:02<00:00, 440.66it/s]


14


100%|██████████| 5000/5000 [00:28<00:00, 174.47it/s]
100%|██████████| 1000/1000 [00:02<00:00, 423.98it/s]


15


100%|██████████| 5000/5000 [00:28<00:00, 173.14it/s]
100%|██████████| 1000/1000 [00:02<00:00, 443.03it/s]


16


100%|██████████| 5000/5000 [00:28<00:00, 175.69it/s]
100%|██████████| 1000/1000 [00:02<00:00, 450.80it/s]


17


100%|██████████| 5000/5000 [00:28<00:00, 176.28it/s]
100%|██████████| 1000/1000 [00:02<00:00, 432.14it/s]


18


100%|██████████| 5000/5000 [00:27<00:00, 180.45it/s]
100%|██████████| 1000/1000 [00:02<00:00, 455.98it/s]


19


100%|██████████| 5000/5000 [00:28<00:00, 176.85it/s]
100%|██████████| 1000/1000 [00:02<00:00, 442.89it/s]


20


100%|██████████| 5000/5000 [00:26<00:00, 188.92it/s]
100%|██████████| 1000/1000 [00:02<00:00, 470.44it/s]


21


100%|██████████| 5000/5000 [00:26<00:00, 189.93it/s]
100%|██████████| 1000/1000 [00:01<00:00, 508.65it/s]


22


100%|██████████| 5000/5000 [00:26<00:00, 191.29it/s]
100%|██████████| 1000/1000 [00:02<00:00, 493.40it/s]


23


100%|██████████| 5000/5000 [00:26<00:00, 190.82it/s]
100%|██████████| 1000/1000 [00:02<00:00, 460.27it/s]


24


100%|██████████| 5000/5000 [00:26<00:00, 188.02it/s]
100%|██████████| 1000/1000 [00:02<00:00, 472.12it/s]


25


100%|██████████| 5000/5000 [00:26<00:00, 189.68it/s]
100%|██████████| 1000/1000 [00:02<00:00, 462.16it/s]


26


100%|██████████| 5000/5000 [00:26<00:00, 187.26it/s]
100%|██████████| 1000/1000 [00:02<00:00, 480.02it/s]


27


100%|██████████| 5000/5000 [00:26<00:00, 190.49it/s]
100%|██████████| 1000/1000 [00:02<00:00, 456.95it/s]


28


100%|██████████| 5000/5000 [00:26<00:00, 189.97it/s]
100%|██████████| 1000/1000 [00:02<00:00, 464.67it/s]


29


100%|██████████| 5000/5000 [00:26<00:00, 187.80it/s]
100%|██████████| 1000/1000 [00:02<00:00, 474.56it/s]


0


100%|██████████| 5000/5000 [00:26<00:00, 187.92it/s]
100%|██████████| 1000/1000 [00:02<00:00, 498.07it/s]


1


100%|██████████| 5000/5000 [00:26<00:00, 189.85it/s]
100%|██████████| 1000/1000 [00:02<00:00, 486.51it/s]


2


100%|██████████| 5000/5000 [00:26<00:00, 190.65it/s]
100%|██████████| 1000/1000 [00:01<00:00, 510.83it/s]


3


100%|██████████| 5000/5000 [00:26<00:00, 188.21it/s]
100%|██████████| 1000/1000 [00:02<00:00, 460.21it/s]


4


100%|██████████| 5000/5000 [00:26<00:00, 187.50it/s]
100%|██████████| 1000/1000 [00:02<00:00, 496.71it/s]


5


100%|██████████| 5000/5000 [00:26<00:00, 188.78it/s]
100%|██████████| 1000/1000 [00:02<00:00, 455.58it/s]


6


100%|██████████| 5000/5000 [00:26<00:00, 191.35it/s]
100%|██████████| 1000/1000 [00:02<00:00, 476.81it/s]


7


100%|██████████| 5000/5000 [00:26<00:00, 189.92it/s]
100%|██████████| 1000/1000 [00:02<00:00, 476.30it/s]


8


100%|██████████| 5000/5000 [00:26<00:00, 187.68it/s]
100%|██████████| 1000/1000 [00:02<00:00, 489.47it/s]


9


100%|██████████| 5000/5000 [00:26<00:00, 188.64it/s]
100%|██████████| 1000/1000 [00:02<00:00, 488.40it/s]


10


100%|██████████| 5000/5000 [00:26<00:00, 189.70it/s]
100%|██████████| 1000/1000 [00:02<00:00, 463.76it/s]


11


100%|██████████| 5000/5000 [00:26<00:00, 190.12it/s]
100%|██████████| 1000/1000 [00:02<00:00, 494.64it/s]


12


100%|██████████| 5000/5000 [00:27<00:00, 184.72it/s]
100%|██████████| 1000/1000 [00:02<00:00, 491.45it/s]


13


100%|██████████| 5000/5000 [00:26<00:00, 189.21it/s]
100%|██████████| 1000/1000 [00:02<00:00, 490.26it/s]


14


100%|██████████| 5000/5000 [00:26<00:00, 189.11it/s]
100%|██████████| 1000/1000 [00:02<00:00, 496.65it/s]


15


100%|██████████| 5000/5000 [00:26<00:00, 188.36it/s]
100%|██████████| 1000/1000 [00:02<00:00, 476.14it/s]


16


100%|██████████| 5000/5000 [00:26<00:00, 188.80it/s]
100%|██████████| 1000/1000 [00:02<00:00, 462.31it/s]


17


100%|██████████| 5000/5000 [00:26<00:00, 189.11it/s]
100%|██████████| 1000/1000 [00:02<00:00, 482.72it/s]


18


100%|██████████| 5000/5000 [00:26<00:00, 191.18it/s]
100%|██████████| 1000/1000 [00:02<00:00, 480.58it/s]


19


100%|██████████| 5000/5000 [00:26<00:00, 185.63it/s]
100%|██████████| 1000/1000 [00:02<00:00, 480.07it/s]


20


100%|██████████| 5000/5000 [00:26<00:00, 191.56it/s]
100%|██████████| 1000/1000 [00:02<00:00, 461.49it/s]


21


100%|██████████| 5000/5000 [00:26<00:00, 185.97it/s]
100%|██████████| 1000/1000 [00:02<00:00, 463.58it/s]


22


100%|██████████| 5000/5000 [00:26<00:00, 188.07it/s]
100%|██████████| 1000/1000 [00:02<00:00, 463.71it/s]


23


100%|██████████| 5000/5000 [00:26<00:00, 188.50it/s]
100%|██████████| 1000/1000 [00:02<00:00, 486.40it/s]


24


100%|██████████| 5000/5000 [00:26<00:00, 188.35it/s]
100%|██████████| 1000/1000 [00:02<00:00, 493.80it/s]


25


100%|██████████| 5000/5000 [00:25<00:00, 198.69it/s]
100%|██████████| 1000/1000 [00:02<00:00, 492.37it/s]


26


100%|██████████| 5000/5000 [00:25<00:00, 195.93it/s]
100%|██████████| 1000/1000 [00:02<00:00, 493.52it/s]


27


100%|██████████| 5000/5000 [00:25<00:00, 196.21it/s]
100%|██████████| 1000/1000 [00:02<00:00, 473.36it/s]


28


100%|██████████| 5000/5000 [00:26<00:00, 191.95it/s]
100%|██████████| 1000/1000 [00:02<00:00, 497.66it/s]


29


100%|██████████| 5000/5000 [00:26<00:00, 192.27it/s]
100%|██████████| 1000/1000 [00:01<00:00, 509.06it/s]


TypeError: can only concatenate list (not "float") to list