# 基于BERT Fine-Tune的语义角色标注模型

## Part I. 模型与数据集构建

### 导入相关库

In [4]:
import torch

import torch.utils.data as torchData

import time

import numpy as np

### 加载pretrained bert model
##### def seq2ids(seq) ->  input tensor of model

In [5]:
import pytorch_pretrained_bert as bert

# Directory the tokenizer and the model are both loaded from
# (presumably a pretrained Chinese BERT checkpoint, judging by the path -- confirm).
bert_model_dir = "bert_model/bert-chinese/"

# Tokenizer used by seq2ids() to map tokens to vocabulary ids.
tokenizer = bert.BertTokenizer.from_pretrained(bert_model_dir)

# NOTE: this rebinds the name `bert` from the imported package alias to the loaded
# BertModel instance. After this line the package is no longer reachable as `bert`;
# SemanticRoleLabelModel.__init__ reads this global to obtain its encoder.
bert = bert.BertModel.from_pretrained(bert_model_dir )

def seq2ids(seq):
    """Turn a token sequence into the id tensor fed to the model.

    Args:
        seq: sequence of tokens understood by the module-level ``tokenizer``.

    Returns:
        A ``torch.Tensor`` built from the token ids wrapped in an outer list,
        so the result carries a leading dimension of size 1.
    """
    token_ids = tokenizer.convert_tokens_to_ids(seq)
    return torch.tensor([token_ids])

### 构建我们的模型

#### 模型在 BERT 之上加入由两个全连接层组成的输出网络，使 BERT 的输出能够转化为对所有语义标签的输出概率

In [6]:
class SemanticRoleLabelModel(torch.nn.Module):
    """BERT encoder followed by a small feed-forward head that scores every label.

    Args:
        n_hidden: width of the hidden layer inside the classification head.
        labelnum: number of labels the head produces a score for.
    """

    def __init__(self, n_hidden, labelnum):
        super().__init__()

        # The encoder is the module-level BertModel instance loaded earlier in the file.
        self.bert = bert

        # 768 is the input width the head expects from the encoder.
        head_layers = [
            torch.nn.Linear(768, n_hidden),
            torch.nn.Tanh(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(n_hidden, labelnum),
            torch.nn.Tanh(),
        ]
        self.output = torch.nn.Sequential(*head_layers)

        # Normalises the head output along dim 1.
        self.prob = torch.nn.Softmax(dim=1)

    def _label_probs(self, seqs):
        """Run one id tensor through the encoder and the head."""
        # [0] selects the first element returned by the encoder
        # (presumably the final encoded layer, since all-layer output is disabled -- confirm).
        encoded = self.bert(seqs, output_all_encoded_layers=False)[0]
        # Index 0 again keeps only the first entry along the leading dimension.
        head_scores = self.output(encoded[0])
        return self.prob(head_scores)

    def forward(self, b_seqs):
        """Return a list with one softmax-normalised tensor per element of ``b_seqs``.

        Args:
            b_seqs: iterable of id tensors, processed one at a time.
        """
        return [self._label_probs(seqs) for seqs in b_seqs]

### 构建数据集

#### 我们使用845篇文章的前600篇作为训练集 后245篇作为测试集

In [10]:
from utils import load_seq, split,Extract_Information_By_Re as E

class MyDataSet(torchData.Dataset):
    """Dataset of (token-id tensor, label indices) pairs built from numbered documents.

    Args:
        train_or_test: ``"train"`` loads documents 0-599, ``"test"`` loads
            documents 600-844.

    Raises:
        ValueError: if ``train_or_test`` is neither ``"train"`` nor ``"test"``.
            Previously an unknown value silently produced an empty dataset.
    """

    # Document index range belonging to each split.
    _SPLIT_RANGES = {"train": range(600), "test": range(600, 845)}

    def __init__(self,train_or_test = "train"):
        if train_or_test not in self._SPLIT_RANGES:
            raise ValueError(
                "train_or_test must be 'train' or 'test', got {!r}".format(train_or_test)
            )

        self.seqs = []
        self.labels = []

        for i in self._SPLIT_RANGES[train_or_test]:
            # Each document is loaded with its labels and split into several
            # (sequence, labels) pairs by the project helpers.
            seqs, labels = split(*load_seq(i, withlabel=True))
            self.seqs += seqs
            self.labels += labels

    def __len__(self):
        """Return the number of sequences collected for this split."""
        return len(self.seqs)

    def __getitem__(self, index):
        """Return ``(x, y)`` for one sequence.

        ``x`` is the id tensor from ``seq2ids`` moved to the GPU; ``y`` is the
        result of ``E.labels2index`` on the matching labels.
        """
        x = seq2ids(self.seqs[index]).cuda()
        y = E.labels2index(self.labels[index])
        return x, y

### 实例化我们的模型，优化器，损失函数 与 数据集
 tip: 
①我们的模型通过GPU计算
②损失函数:交叉熵
③优化器:Adam

In [11]:
# Classification head with 300 hidden units and one output per entry of E.labels.
Model = SemanticRoleLabelModel(300, len(E.labels))

# Wrap for multi-GPU execution and move all parameters to the GPU.
Model = torch.nn.DataParallel(Model).cuda()

# Every parameter (pretrained encoder included) is optimised, so the step size must be
# in the fine-tuning range recommended for BERT (2e-5 to 5e-5). The previous 1e-2 is
# large enough to wipe out the pretrained weights within a few updates.
optimizer = torch.optim.Adam(Model.parameters(), lr=2e-5)

# CrossEntropyLoss applies log-softmax itself, so it expects unnormalised scores
# or log-probabilities as input.
loss = torch.nn.CrossEntropyLoss()

trainingSet = MyDataSet(train_or_test = "train")


## Part II. 训练模型

In [23]:
# The identity collate_fn keeps each mini-batch as a plain list of (ids, labels) pairs
# instead of trying to stack them into a single tensor.
dataGenerator = torchData.DataLoader(dataset=trainingSet, batch_size=10, shuffle=True, collate_fn=lambda data: data)

# NOTE(review): the recorded run ended with "CUDA error: device-side assert triggered".
# On CUDA that is commonly an out-of-range index (e.g. a label id >= the number of model
# outputs, or an input longer than the encoder supports) -- rerun one batch on CPU to get
# the precise error and confirm.
for epoch in range(10):

    # Training mode keeps the dropout layer active.
    Model.train()

    epoch_loss = []
    start_time = time.time()
    # Defined up front so the summary line below also works for an empty loader.
    time_usage = 0.0
    used_seq = 0

    for data in dataGenerator:

        b_seq = [item[0] for item in data]
        b_label = [item[1] for item in data]

        time_usage = time.time() - start_time
        used_seq += len(b_seq)
        process = str(round(used_seq / len(trainingSet) * 100, 6)) + "%"

        print("\r epoch : {},process : {} ,timeUsage:{}".format(epoch, process, time_usage), end="", flush=True)

        b_prob = Model(b_seq)
        b_target = [torch.tensor(label).cuda() for label in b_label]

        # The model already returns softmax probabilities, while CrossEntropyLoss applies
        # log-softmax itself. Passing log(p) makes the loss equal to -log p[target]
        # (log-softmax of log-probabilities is the identity) instead of normalising twice.
        # The clamp guards against log(0).
        sentence_losses = [
            loss(torch.log(prob.clamp(min=1e-12)), target)
            for prob, target in zip(b_prob, b_target)
        ]

        # Average over the sentences of the mini-batch.
        b_loss = torch.stack(sentence_losses).mean()
        epoch_loss.append(b_loss.item())

        optimizer.zero_grad()
        b_loss.backward()
        optimizer.step()

    mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float("nan")

    print("\repoch {} : Loss :{} ,timeUsage:{}".format(epoch, mean_loss, time_usage))


RuntimeError: CUDA error: device-side assert triggered