BiLSTM序列标注---盲汉翻译

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random # for dataset shuffling
import openpyxl # for recording experimental value
import copy

https://blog.csdn.net/vivian_ll/article/details/93894151

In [5]:
class LSTMTagger(nn.Module):
    """Single-layer, unidirectional LSTM sequence tagger.

    Token indices are embedded, passed through an LSTM one sentence at a time
    (batch size 1), and every time step's LSTM output is projected to
    log-probabilities over the tag set.

    Args:
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
        vocab_size: Number of distinct token indices.
        tagset_size: Number of distinct tags.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM consumes the word embeddings and emits hidden states of
        # size hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # Linear layer mapping the hidden-state space onto the tag space.
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh zero (h_0, c_0) pair for the LSTM.

        Each tensor has shape (num_layers * num_directions, batch_size,
        hidden_dim) == (1, 1, hidden_dim). The tensors are allocated with the
        same device and dtype as the model parameters, so the state stays
        usable after ``model.to(device)``.
        """
        reference = self.word_embeddings.weight
        return (reference.new_zeros(1, 1, self.hidden_dim),
                reference.new_zeros(1, 1, self.hidden_dim))

    def forward(self, sentence):
        """Score every token of one sentence.

        Args:
            sentence: 1-D LongTensor of token indices.

        Returns:
            Tensor of shape (len(sentence), tagset_size) holding per-token
            log-probabilities over the tags.

        Note:
            ``self.hidden`` is carried over between calls; callers that want
            independent sentences must reset it with ``init_hidden()`` first.
        """
        seq_len = len(sentence)
        embeds = self.word_embeddings(sentence)
        # self.hidden is a plain attribute, not a registered buffer, so it does
        # not follow model.to(device); align it with the input explicitly.
        hidden = tuple(state.to(embeds.device) for state in self.hidden)
        lstm_out, self.hidden = self.lstm(embeds.view(seq_len, 1, -1), hidden)
        tag_space = self.hidden2tag(lstm_out.view(seq_len, -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

In [6]:
# Build the raw dataset: one (tokens, tags) pair per corpus line.
original_data = []

# Context managers make sure both file handles are closed after reading.
with open('./对应拼音标调.txt', "r", encoding='UTF8') as tag_file:
    tags = tag_file.readlines()

# NOTE(review): the sentence file is read in binary mode, so its tokens are
# `bytes` while the tags are `str`. This is kept as-is so the vocabulary keys
# stay unchanged — confirm whether text mode was intended.
with open("./训练集.txt", "rb") as sentence_file:
    # Read every line so the sentences can be paired with the tags by index.
    lines = sentence_file.readlines()

# Training data: 6692 (total number of sentences) * 2 (line & tag), per the
# original note. Indexing `tags` raises IndexError if the tag file has fewer
# lines than the sentence file.
for index, raw_line in enumerate(lines):
    tokens = raw_line.rstrip().split()
    tag = tags[index].rstrip().split()
    original_data.append((tokens, tag))

In [7]:
# Word -> index vocabulary built from every sentence in the corpus.
# `original_data` is used here because `training_data` is only created later,
# inside the training cell, so referencing it fails on a fresh kernel.
word_to_ix = {}
for sent, _ in original_data:
    for word in sent:
        # setdefault keeps the first index assigned to each distinct word.
        word_to_ix.setdefault(word, len(word_to_ix))

# Tag vocabulary: the digit labels "0"-"4" plus four punctuation marks.
tag_to_ix = {"0":0, "1":1, "2":2, "3":3, "4":4, ",":5, ".":6, "?":7,"!":8}

# np.save pickles a dict into a 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save('word_to_ix.npy',word_to_ix)
np.save('tag_to_ix.npy',tag_to_ix)

In [8]:
# Select the compute device: first CUDA GPU when present, otherwise the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

if not use_gpu:
    print("Running on the CPU")
else:
    print("Running on the GPU")
    print("number of device",torch.cuda.device_count())

Running on the CPU


In [9]:
# Training the LSTM

# Monitor training results
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

#Helper function to convert data into tensor
def prepare_sequence(seq, to_ix):
    """Map each item of `seq` through `to_ix` and return the indices as a LongTensor.

    Raises KeyError when an item of `seq` is missing from `to_ix`.
    """
    indices = []
    for token in seq:
        indices.append(to_ix[token])
    return torch.tensor(indices, dtype=torch.long)

# Accuracy recording in excel: one row per hidden-size configuration.
workbook = openpyxl.Workbook()
worksheet = workbook.active

# Experiment configuration.
EMBEDDING_DIM = 60
EPOCH = 30
TRAIN_END = 5000  # sentences [0, TRAIN_END) of the shuffled data are used for training
VAL_END = 6000    # sentences [TRAIN_END, VAL_END) are used for validation


def _run_epoch(model, data, loss_function, optimizer=None, desc=None):
    """Run one pass over `data`, training when `optimizer` is given.

    Reads the notebook-level `device`, `word_to_ix` and `tag_to_ix`.

    Args:
        model: Tagger exposing `init_hidden()` and a settable `hidden` attribute.
        data: Sequence of (sentence, tags) pairs.
        loss_function: Callable taking (tag_scores, targets) and returning the loss.
        optimizer: Optimizer to step after every sentence; None means
            evaluation only (no gradients, no parameter updates).
        desc: Label shown on the progress bar.

    Returns:
        (summed loss over all sentences, mean per-sentence accuracy).
        The accuracy is NaN when `data` is empty.
    """
    training = optimizer is not None
    model.train(training)
    losses = []
    accuracies = []
    # Autograd is only needed when parameters are going to be updated.
    with torch.set_grad_enabled(training):
        for sentence, tags in tqdm(data, desc=desc):
            if training:
                # Clear gradients because PyTorch accumulates them.
                model.zero_grad()
            # Reset the LSTM state so sentences stay independent, and make
            # sure the state lives on the same device as the model.
            model.hidden = tuple(state.to(device) for state in model.init_hidden())

            # Turn tokens and tags into index tensors on the target device.
            sentence_in = prepare_sequence(sentence, word_to_ix).to(device)
            targets = prepare_sequence(tags, tag_to_ix).to(device)

            # Forward propagation.
            tag_scores = model(sentence_in)
            loss = loss_function(tag_scores, targets)

            if training:
                # Back-propagate and update the parameters.
                loss.backward()
                optimizer.step()

            predictions = torch.argmax(tag_scores, dim=-1)
            accuracies.append((predictions == targets).sum().item() / len(targets))
            losses.append(loss.item())

    mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else float("nan")
    return sum(losses), mean_accuracy


# Change parameters to reach better performances: sweep the hidden size.
for i in range(10):
    # Defined before first use so the log directory name matches the model.
    HIDDEN_DIM = 10 + 20*i

    # Record results; the leading string labels each group inside the row.
    acc_repeat_train = ['acc_train']
    acc_repeat_val = ["acc_val"]
    loss_repeat_train = ["loss_repeat_train"]
    loss_repeat_val = ["loss_repeat_val"]

    # Repeat three times with a new model and a differently shuffled dataset.
    for repeat in range(3):
        # Set data writer (one TensorBoard run directory per configuration/repeat).
        g = "./log/"+"Hidden_dim_"+str(HIDDEN_DIM)+"_Embedding_dim_"+str(EMBEDDING_DIM)+"repeat_"+str(repeat)
        writer = SummaryWriter(g)
        try:
            # Seed both RNGs so the shuffle and the weight initialisation are
            # reproducible for a given repeat index.
            random.seed(repeat)
            torch.manual_seed(repeat)

            model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
            model = model.to(device)  # For GPU calculation
            loss_function = nn.NLLLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

            # Shuffle a deep copy so the original data is not influenced.
            training_data = copy.deepcopy(original_data)
            random.shuffle(training_data)

            for epoch in range(EPOCH):
                loss_train, acc_train = _run_epoch(
                    model, training_data[:TRAIN_END], loss_function, optimizer,
                    desc="epoch " + str(epoch) + " train")
                loss_val, acc_val = _run_epoch(
                    model, training_data[TRAIN_END:VAL_END], loss_function,
                    desc="epoch " + str(epoch) + " val")

                writer.add_scalar('Loss/training', loss_train, epoch)
                writer.add_scalar('Loss/validation', loss_val, epoch)
                writer.add_scalar('Accuracy/training', acc_train, epoch)
                writer.add_scalar('Accuracy/validation', acc_val, epoch)
        finally:
            # Close every writer, not only the last one created.
            writer.close()

        # Keep the final-epoch metrics of this repeat.
        acc_repeat_train.append(acc_train)
        acc_repeat_val.append(acc_val)
        loss_repeat_train.append(loss_train)
        loss_repeat_val.append(loss_val)

    worksheet.append(acc_repeat_train+acc_repeat_val+loss_repeat_train+loss_repeat_val)
    # Save after every configuration so an interrupted sweep keeps its results.
    workbook.save("Result.xlsx")

0


100%|██████████| 5000/5000 [00:31<00:00, 160.95it/s]
100%|██████████| 1000/1000 [00:02<00:00, 481.88it/s]


1


100%|██████████| 5000/5000 [00:27<00:00, 184.25it/s]
100%|██████████| 1000/1000 [00:01<00:00, 507.27it/s]


2


100%|██████████| 5000/5000 [00:27<00:00, 183.08it/s]
100%|██████████| 1000/1000 [00:02<00:00, 488.47it/s]


3


100%|██████████| 5000/5000 [00:27<00:00, 182.86it/s]
100%|██████████| 1000/1000 [00:02<00:00, 489.40it/s]


4


100%|██████████| 5000/5000 [00:30<00:00, 165.50it/s]
100%|██████████| 1000/1000 [00:02<00:00, 414.55it/s]


5


100%|██████████| 5000/5000 [00:28<00:00, 173.23it/s]
100%|██████████| 1000/1000 [00:02<00:00, 487.90it/s]


6


100%|██████████| 5000/5000 [00:28<00:00, 174.23it/s]
100%|██████████| 1000/1000 [00:02<00:00, 454.01it/s]


7


100%|██████████| 5000/5000 [00:28<00:00, 175.80it/s]
100%|██████████| 1000/1000 [00:02<00:00, 467.48it/s]


8


100%|██████████| 5000/5000 [00:28<00:00, 174.09it/s]
100%|██████████| 1000/1000 [00:02<00:00, 447.14it/s]


9


100%|██████████| 5000/5000 [00:29<00:00, 168.92it/s]
100%|██████████| 1000/1000 [00:02<00:00, 430.44it/s]


10


100%|██████████| 5000/5000 [00:30<00:00, 165.90it/s]
100%|██████████| 1000/1000 [00:02<00:00, 433.05it/s]


11


100%|██████████| 5000/5000 [00:30<00:00, 166.22it/s]
100%|██████████| 1000/1000 [00:02<00:00, 425.13it/s]


12


100%|██████████| 5000/5000 [00:32<00:00, 154.52it/s]
100%|██████████| 1000/1000 [00:02<00:00, 432.49it/s]


13


100%|██████████| 5000/5000 [00:31<00:00, 160.50it/s]
100%|██████████| 1000/1000 [00:02<00:00, 380.24it/s]


14


100%|██████████| 5000/5000 [00:31<00:00, 159.33it/s]
100%|██████████| 1000/1000 [00:02<00:00, 410.82it/s]


15


100%|██████████| 5000/5000 [00:30<00:00, 163.26it/s]
100%|██████████| 1000/1000 [00:02<00:00, 427.60it/s]


16


100%|██████████| 5000/5000 [00:30<00:00, 162.42it/s]
100%|██████████| 1000/1000 [00:02<00:00, 419.11it/s]


17


100%|██████████| 5000/5000 [00:30<00:00, 161.40it/s]
100%|██████████| 1000/1000 [00:02<00:00, 436.58it/s]


18


100%|██████████| 5000/5000 [00:30<00:00, 161.74it/s]
100%|██████████| 1000/1000 [00:02<00:00, 433.57it/s]


19


100%|██████████| 5000/5000 [00:31<00:00, 157.52it/s]
100%|██████████| 1000/1000 [00:02<00:00, 434.86it/s]


20


100%|██████████| 5000/5000 [00:30<00:00, 164.11it/s]
100%|██████████| 1000/1000 [00:02<00:00, 411.19it/s]


21


100%|██████████| 5000/5000 [00:30<00:00, 162.28it/s]
100%|██████████| 1000/1000 [00:02<00:00, 420.67it/s]


22


100%|██████████| 5000/5000 [00:30<00:00, 163.62it/s]
100%|██████████| 1000/1000 [00:02<00:00, 422.67it/s]


23


100%|██████████| 5000/5000 [00:30<00:00, 161.68it/s]
100%|██████████| 1000/1000 [00:02<00:00, 435.11it/s]


24


100%|██████████| 5000/5000 [00:30<00:00, 163.14it/s]
100%|██████████| 1000/1000 [00:02<00:00, 424.83it/s]


25


100%|██████████| 5000/5000 [00:30<00:00, 161.76it/s]
100%|██████████| 1000/1000 [00:02<00:00, 432.81it/s]


26


100%|██████████| 5000/5000 [00:31<00:00, 159.17it/s]
100%|██████████| 1000/1000 [00:02<00:00, 424.71it/s]


27


100%|██████████| 5000/5000 [00:30<00:00, 162.24it/s]
100%|██████████| 1000/1000 [00:02<00:00, 430.50it/s]


28


100%|██████████| 5000/5000 [00:30<00:00, 162.27it/s]
100%|██████████| 1000/1000 [00:02<00:00, 418.14it/s]


29


100%|██████████| 5000/5000 [00:30<00:00, 164.04it/s]
100%|██████████| 1000/1000 [00:02<00:00, 436.15it/s]


0


100%|██████████| 5000/5000 [00:30<00:00, 163.56it/s]
100%|██████████| 1000/1000 [00:02<00:00, 408.19it/s]


1


100%|██████████| 5000/5000 [00:30<00:00, 164.17it/s]
100%|██████████| 1000/1000 [00:02<00:00, 413.33it/s]


2


100%|██████████| 5000/5000 [00:31<00:00, 160.28it/s]
100%|██████████| 1000/1000 [00:02<00:00, 421.53it/s]


3


100%|██████████| 5000/5000 [00:30<00:00, 161.85it/s]
100%|██████████| 1000/1000 [00:02<00:00, 432.44it/s]


4


100%|██████████| 5000/5000 [00:30<00:00, 163.25it/s]
100%|██████████| 1000/1000 [00:02<00:00, 413.12it/s]


5


100%|██████████| 5000/5000 [00:30<00:00, 163.89it/s]
100%|██████████| 1000/1000 [00:02<00:00, 426.13it/s]


6


100%|██████████| 5000/5000 [00:31<00:00, 160.85it/s]
100%|██████████| 1000/1000 [00:02<00:00, 396.81it/s]


7


100%|██████████| 5000/5000 [00:28<00:00, 175.47it/s]
100%|██████████| 1000/1000 [00:04<00:00, 248.71it/s]


8


100%|██████████| 5000/5000 [05:58<00:00, 13.93it/s] 
100%|██████████| 1000/1000 [00:02<00:00, 493.91it/s]


9


 15%|█▌        | 755/5000 [00:03<00:22, 188.39it/s]

: 