In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext.legacy.datasets import SequenceTaggingDataset
from torchtext.vocab import Vocab
from torchtext.legacy import data
from torchtext.legacy.data import Field
from torch.utils.data import DataLoader
#from TorchCRF import CRF
from torchcrf import CRF

In [34]:
# 定义模型类
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim * 2, tag_vocab_size)
        self.crf = CRF(tag_vocab_size)
    
    def forward(self, text):
        embeds = self.embedding(text)
        lstm_out, _ = self.bilstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        return tag_space
    
    # def loss(self, text, tags):
    #     emissions = self.forward(text)
    #     loss = self.crf(emissions, tags)
    #     return torch.mean(loss)
    def loss(self, text, tags):
        emissions = self.forward(text)
        mask = text.ne(0)
        loss =  -self.crf(emissions, tags, mask=mask)
        return torch.mean(loss)
    
    def decode(self, text):
        emissions = self.forward(text)
        return self.crf.decode(emissions)

In [35]:
# 加载数据集
TEXT = Field(lower=True, include_lengths=True, batch_first=True)
TAGS = Field(unk_token=None, batch_first=True)
train_data, val_data, test_data = SequenceTaggingDataset.splits(
    path='/data/wyf/InformationRetrievalProject/data/', train='eng_train.txt', validation='eng_testa.txt', test='eng_testb.txt',
    fields=(('text', TEXT), ('tags', TAGS)), separator=' ')
TEXT.build_vocab(train_data)
TAGS.build_vocab(train_data)

In [36]:
# 定义模型参数
vocab_size = len(TEXT.vocab)
tag_vocab_size = len(TAGS.vocab)
embedding_dim = 100
hidden_dim = 128
model = BiLSTM_CRF(vocab_size, tag_vocab_size, embedding_dim, hidden_dim)

In [37]:
# 迭代器
train_iterator = data.BucketIterator(train_data,batch_size=10,train=True,shuffle=True)
val_iterator = data.BucketIterator(val_data,batch_size=len(val_data),train=False,sort=False)
test_iterator = data.BucketIterator(test_data,batch_size=len(test_data),train=False,sort=False)

In [38]:
import time

In [39]:
# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
num_epochs = 10
best_accuracy = 0.0

optimizer = optim.Adam(model.parameters())
train_loader = DataLoader(train_data, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)
model.to(device)

BiLSTM_CRF(
  (embedding): Embedding(21012, 100)
  (bilstm): LSTM(100, 128, batch_first=True, bidirectional=True)
  (hidden2tag): Linear(in_features=256, out_features=47, bias=True)
  (crf): CRF(num_tags=47)
)

In [40]:
start_time = time.time()
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    accuracy = 0.0
    total_correct = 0.0
    total_data_num = len(train_iterator.dataset)
    steps = 0.0
    
    
    for batch in train_iterator:
        steps += 1
        
        text, text_lengths = batch.text
        tags = batch.tags
        text = text.to(device)
        tags = tags.to(device)
        
        optimizer.zero_grad()
        loss = model.loss(text, tags)
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        if steps%100==0:
            print("Epoch %d_%.3f%%:  Training average Loss: %f"
                      %(epoch, steps * train_iterator.batch_size*100/len(train_iterator.dataset),total_loss/steps))
    
    print("Train done. Now evaluate.")
        
    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        total_tokens = 0.0
        total_correct = 0.0
        steps = 0.0
        for batch in val_iterator:
            steps += 1
            text, text_lengths = batch.text
            tags = batch.tags
            text = text.to(device)
            tags = tags.to(device)
            emissions = model.forward(text)
            predicted_tags = model.decode(text)
            loss = -model.crf(emissions, tags)
            total_loss += loss.item()
            
            total_tokens += text_lengths.sum().item()
            #total_correct += sum([1 for i in range(len(predicted_tags)) for j in range(len(predicted_tags[i])) if predicted_tags[i][j] == tags[i][j]])
            for i in range(len(predicted_tags)):
                #for j in range(len(predicted_tags[i])):
                for j in range(text_lengths[i]):
                    if predicted_tags[i][j] == tags[j][i]:
                        total_correct += 1
            
        val_loss = total_loss / len(val_data)
        val_accuracy = total_correct / total_tokens
        print(f'Epoch {epoch} | Val Loss: {val_loss:.3f} | Val Accuracy: {val_accuracy:.3f}')

Epoch 0_6.672%:  Training average Loss: 395.550216
Epoch 0_13.345%:  Training average Loss: 303.157243
Epoch 0_20.017%:  Training average Loss: 257.670748
Epoch 0_26.690%:  Training average Loss: 230.556371
Epoch 0_33.362%:  Training average Loss: 209.575772
Epoch 0_40.035%:  Training average Loss: 193.945254
Epoch 0_46.707%:  Training average Loss: 181.733447
Epoch 0_53.380%:  Training average Loss: 171.855697
Epoch 0_60.052%:  Training average Loss: 163.659879
Epoch 0_66.724%:  Training average Loss: 156.957508
Epoch 0_73.397%:  Training average Loss: 150.740822
Epoch 0_80.069%:  Training average Loss: 145.227902
Epoch 0_86.742%:  Training average Loss: 140.876290
Epoch 0_93.414%:  Training average Loss: 136.400189
Train done. Now evaluate.
Epoch 0 | Val Loss: 8.390 | Val Accuracy: 0.001
Epoch 1_6.672%:  Training average Loss: 67.186603
Epoch 1_13.345%:  Training average Loss: 65.737945
Epoch 1_20.017%:  Training average Loss: 64.547185
Epoch 1_26.690%:  Training average Loss: 65.506

In [60]:
# 在测试集上评估模型
test_loader = DataLoader(test_data, batch_size=batch_size)
model.eval()
with torch.no_grad():
    total_loss = 0
    total_tokens = 0
    total_size = 0
    total_correct = 0
    for batch in test_iterator:
        text, text_lengths = batch.text
        tags = batch.tags
        text = text.to(device)
        tags = tags.to(device)
        emissions = model.forward(text)
        predicted_tags = model.decode(text)
        loss = -model.crf(emissions, tags)
        total_loss += loss.item()
        total_tokens += text_lengths.sum().item()
        
        #total_correct += sum([1 for i in range(len(predicted_tags)) for j in range(len(predicted_tags[i])) if predicted_tags[i][j] == tags[i][j]])
        for i in range(len(predicted_tags)):
            for j in range(len(predicted_tags[i])):
                total_size += 1
            #for j in range(text_lengths[i]):
                if predicted_tags[i][j] == tags[j][i]:
                    total_correct += 1
    test_loss = total_loss / len(test_data)
    #test_accuracy = total_correct / total_tokens
    test_accuracy = total_correct / total_size
    print(total_correct)
    print(total_tokens)
    print(total_size)
    print(f'Test Loss: {test_loss:.3f} | Test Accuracy: {test_accuracy:.3f}')

450968
46666
456816
Test Loss: 8.170 | Test Accuracy: 0.987
